feat(language): task_aug style + automatic ${task} rephrasing rotation

Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
  Python's randomized hash) so an epoch sees every rephrasing of every
  episode while the same sample still resolves identically across
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-30 16:45:39 +02:00
parent 1ca38d9748
commit c1a0c601e2
4 changed files with 142 additions and 6 deletions
+11 -2
View File
@@ -27,11 +27,20 @@ LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls") PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls") EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"} CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
EXTENDED_STYLES = set() EXTENDED_STYLES = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"} PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"} EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these # Styles whose ``content`` is grounded in a specific camera view. Rows of these
+48 -3
View File
@@ -198,6 +198,7 @@ def render_sample(
persistent=persistent_rows, persistent=persistent_rows,
events=event_rows, events=event_rows,
t=t, t=t,
sample_idx=sample_idx,
task=task, task=task,
dataset_ctx=dataset_ctx, dataset_ctx=dataset_ctx,
) )
@@ -232,21 +233,65 @@ def _resolve_bindings(
persistent: Sequence[LanguageRow], persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow], events: Sequence[LanguageRow],
t: float, t: float,
sample_idx: int,
task: str | None, task: str | None,
dataset_ctx: Any | None, dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]: ) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``.""" """Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)} bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})} specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items(): for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t) bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings return bindings
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None: def _resolve_task(
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``.""" task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None: if task is not None:
return task return task
aug_rows = [
r
for r in persistent
if r.get("style") == "task_aug" and r.get("role") == "user"
]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(
f"task_aug:{sample_idx}".encode(), digest_size=8
).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None: if dataset_ctx is None:
return None return None
if isinstance(dataset_ctx, dict): if isinstance(dataset_ctx, dict):
+2 -1
View File
@@ -43,7 +43,7 @@ def test_language_arrow_schema_has_expected_fields():
def test_style_registry_routes_columns(): def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
@@ -51,6 +51,7 @@ def test_style_registry_routes_columns():
assert column_for_style("plan") == LANGUAGE_PERSISTENT assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS assert column_for_style("trace") == LANGUAGE_EVENTS
+81
View File
@@ -289,6 +289,87 @@ def test_per_camera_blend_renders_both_views():
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}' assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
rephrasings = [
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
# No explicit task override → resolver consults persistent rows.
seen: set[str] = set()
for sample_idx in range(64):
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=sample_idx,
dataset_ctx={"task": "canonical kitchen task"},
)
seen.add(rendered["messages"][0]["content"])
# Every rephrasing should be reachable across enough samples.
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT, # no task_aug rows
events=[],
t=0.0,
sample_idx=0,
dataset_ctx={"task": "clean the kitchen"},
)
assert rendered["messages"][0]["content"] == "clean the kitchen"
def test_resolve_task_explicit_override_beats_rephrasings():
rephrasings = [
persistent_row("user", "rephrased one", "task_aug", 0.0),
persistent_row("user", "rephrased two", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=0,
task="explicit override wins",
dataset_ctx={"task": "canonical"},
)
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_canonical_recipe_can_render_low_level_branch(): def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]}) low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})