mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
annotate: dedup task_aug + row-normalization; docs module on/off table
Two behavior-preserving simplifications:
* plan_subtasks_memory.run_episode: the task_aug 'axes' and free-form
branches built identical deduped rows via copy-pasted seen/append
loops. Collapse to one branch that picks the variant source, then a
shared _task_aug_rows() helper does the dedup + row build (-~25 LOC).
* writer: _normalize_persistent_row / _normalize_event_row shared the
same camera-validate + struct construction. Extract _normalize_row(),
keeping the exact key order (the parquet struct schema is inferred
from insertion order, so timestamp must stay between style and camera).
docs: 'Which modules run' is now a table giving each module's on/off flag
(--plan.enabled / --interjections.enabled / --vqa.enabled) and what it
turns off.
Verified: 40 tests pass (incl. test_writer struct round-trip); pre-commit clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -136,9 +136,14 @@ short manipulation episodes.
|
||||
|
||||
### Which modules run
|
||||
|
||||
Each module can be turned off independently to iterate on one at a time:
|
||||
`--plan.enabled`, `--interjections.enabled`, `--vqa.enabled` (all
|
||||
`true` by default).
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
|
||||
| ------------------------- | ------- | ----------------------------------- |
|
||||
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
|
||||
@@ -71,47 +71,16 @@ class PlanSubtasksMemoryModule:
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings.
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
axes_cfg = self.config.task_aug_axes
|
||||
if axes_cfg.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, axes_cfg)
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *variants]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
# Include the effective task first so the rotation always covers
|
||||
# the source-of-truth phrasing, not just synthetic ones.
|
||||
seen = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
@@ -233,6 +202,21 @@ class PlanSubtasksMemoryModule:
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
|
||||
@@ -89,6 +89,27 @@ def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
)
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
@@ -100,22 +121,10 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -129,15 +138,7 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
|
||||
Reference in New Issue
Block a user