annotate: dedup task_aug + row-normalization; docs module on/off table

Two behavior-preserving simplifications:
  * plan_subtasks_memory.run_episode: the task_aug 'axes' and free-form
    branches built identical deduped rows via copy-pasted seen/append
    loops. Collapse to one branch that picks the variant source, then a
    shared _task_aug_rows() helper does the dedup + row build (-~25 LOC).
  * writer: _normalize_persistent_row / _normalize_event_row shared the
    same camera-validate + struct construction. Extract _normalize_row(),
    keeping the exact key order (the parquet struct schema is inferred
    from insertion order, so timestamp must stay between style and camera).

docs: 'Which modules run' is now a table giving each module's on/off flag
(--plan.enabled / --interjections.enabled / --vqa.enabled) and what it
turns off.

Verified: 40 tests pass (incl. test_writer struct round-trip); pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-04 14:18:36 +02:00
parent 7471a6b1ed
commit 973318ef65
3 changed files with 56 additions and 66 deletions
+8 -3
View File
@@ -136,9 +136,14 @@ short manipulation episodes.
### Which modules run
Each module can be turned off independently to iterate on one at a time:
`--plan.enabled`, `--interjections.enabled`, `--vqa.enabled` (all
`true` by default).
Every module is on by default and can be toggled independently (set to
`false` to skip it, e.g. to iterate on one module at a time):
| Flag | Default | Turns off |
| ------------------------- | ------- | ----------------------------------- |
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
| `--interjections.enabled` | `true` | interjections + speech atoms |
| `--vqa.enabled` | `true` | the VQA pairs |
### The VLM (`--vlm.*`)
@@ -71,47 +71,16 @@ class PlanSubtasksMemoryModule:
effective_task = self._resolve_effective_task(record)
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
# free-form n_task_rephrasings.
# free-form n_task_rephrasings; the effective task is always emitted
# first so the rotation covers the source-of-truth phrasing.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
axes_cfg = self.config.task_aug_axes
if axes_cfg.enabled and effective_task:
variants = self._generate_task_aug_by_axes(effective_task, axes_cfg)
seen: set[str] = set()
ordered = [effective_task, *variants]
for phrasing in ordered:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{
"role": "user",
"content": key,
"style": "task_aug",
"timestamp": t0,
"tool_calls": None,
}
)
variants: list[str] | None = None
if self.config.task_aug_axes.enabled and effective_task:
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
elif self.config.n_task_rephrasings > 0 and effective_task:
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
# Include the effective task first so the rotation always covers
# the source-of-truth phrasing, not just synthetic ones.
seen = set()
ordered = [effective_task, *rephrasings]
for phrasing in ordered:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{
"role": "user",
"content": key,
"style": "task_aug",
"timestamp": t0,
"tool_calls": None,
}
)
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
if variants is not None:
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
subtask_spans = self._generate_subtasks(record, task=effective_task)
@@ -233,6 +202,21 @@ class PlanSubtasksMemoryModule:
return True
return task.lower() in self._PLACEHOLDER_TASKS
@staticmethod
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
seen: set[str] = set()
rows: list[dict[str, Any]] = []
for phrasing in phrasings:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
)
return rows
# ------------------------------------------------------------------
# VLM call helpers — every plan-module prompt follows the same shape:
# build messages → single VLM call → pull a named field.
@@ -89,6 +89,27 @@ def _row_event_sort_key(row: dict[str, Any]) -> tuple:
)
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
"""Coerce a staged row into the language-column struct shape.
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
writer infers the parquet struct schema from insertion order, so
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
"""
camera = row.get("camera")
validate_camera_field(style, camera)
out: dict[str, Any] = {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
}
if with_timestamp:
out["timestamp"] = float(row["timestamp"])
out["camera"] = None if camera is None else str(camera)
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
return out
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
@@ -100,22 +121,10 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
if "role" not in row:
# Surface a friendly error from the writer rather than letting
# the raw KeyError bubble out of the dict access below — modules
# are expected to always emit ``role``, but the validator
# currently doesn't check this so a future bug would otherwise
# be hard to triage.
# Friendly error from the writer instead of a raw KeyError below;
# the validator doesn't check ``role`` yet.
raise ValueError(f"persistent row missing role: {row!r}")
camera = row.get("camera")
validate_camera_field(style, camera)
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"timestamp": float(row["timestamp"]),
"camera": None if camera is None else str(camera),
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
return _normalize_row(row, style, with_timestamp=True)
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
@@ -129,15 +138,7 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
raise ValueError(f"event row with style {style!r} would not route to language_events")
if "role" not in row:
raise ValueError(f"event row missing role: {row!r}")
camera = row.get("camera")
validate_camera_field(style, camera)
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"camera": None if camera is None else str(camera),
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
return _normalize_row(row, style, with_timestamp=False)
def _normalize_tool_calls(value: Any) -> list[Any] | None: