mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.
- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
``language_persistent`` so PR 2 writers route it correctly.
- ``language_render.py``: ``_resolve_task`` now consults the persistent
slice for rows of ``style="task_aug", role="user"``. When any exist
it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
Python's process-randomized ``hash``) so samples spread pseudo-randomly over
each episode's rephrasings (coverage is statistical, not guaranteed) while the same sample still resolves identically across
reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
no rephrasings are present, so existing datasets and unannotated runs
keep their behaviour. Explicit ``task=`` overrides still win.
- Tests: rephrasing coverage across samples, determinism on repeat
``sample_idx``, fallback when persistent has no ``task_aug`` rows,
and explicit override priority.
Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -27,11 +27,20 @@ LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
|||||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||||
|
|
||||||
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
|
CORE_STYLES = {
|
||||||
|
"subtask",
|
||||||
|
"plan",
|
||||||
|
"memory",
|
||||||
|
"motion",
|
||||||
|
"interjection",
|
||||||
|
"vqa",
|
||||||
|
"trace",
|
||||||
|
"task_aug",
|
||||||
|
}
|
||||||
EXTENDED_STYLES = set()
|
EXTENDED_STYLES = set()
|
||||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||||
|
|
||||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
|
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||||
|
|
||||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ def render_sample(
|
|||||||
persistent=persistent_rows,
|
persistent=persistent_rows,
|
||||||
events=event_rows,
|
events=event_rows,
|
||||||
t=t,
|
t=t,
|
||||||
|
sample_idx=sample_idx,
|
||||||
task=task,
|
task=task,
|
||||||
dataset_ctx=dataset_ctx,
|
dataset_ctx=dataset_ctx,
|
||||||
)
|
)
|
||||||
@@ -232,21 +233,65 @@ def _resolve_bindings(
|
|||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
events: Sequence[LanguageRow],
|
events: Sequence[LanguageRow],
|
||||||
t: float,
|
t: float,
|
||||||
|
sample_idx: int,
|
||||||
task: str | None,
|
task: str | None,
|
||||||
dataset_ctx: Any | None,
|
dataset_ctx: Any | None,
|
||||||
) -> dict[str, LanguageRow | str | None]:
|
) -> dict[str, LanguageRow | str | None]:
|
||||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||||
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
|
bindings: dict[str, LanguageRow | str | None] = {
|
||||||
|
"task": _resolve_task(
|
||||||
|
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
||||||
|
),
|
||||||
|
}
|
||||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||||
for name, spec in specs.items():
|
for name, spec in specs.items():
|
||||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||||
return bindings
|
return bindings
|
||||||
|
|
||||||
|
|
||||||
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
|
def _resolve_task(
|
||||||
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``."""
|
task: str | None,
|
||||||
|
dataset_ctx: Any | None,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow] = (),
|
||||||
|
sample_idx: int = 0,
|
||||||
|
) -> str | None:
|
||||||
|
"""Return the task string for ``sample_idx``.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
|
||||||
|
1. Explicit ``task`` override (caller-supplied) wins.
|
||||||
|
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||||
|
deterministically pick one by ``sample_idx`` so each frame of an
|
||||||
|
episode rotates through the available rephrasings across an epoch.
|
||||||
|
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||||
|
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||||
|
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||||
|
and falls back to the canonical task otherwise. Recipes that want
|
||||||
|
the literal canonical task can override the binding.
|
||||||
|
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||||
|
backed by ``meta/tasks.parquet``).
|
||||||
|
"""
|
||||||
if task is not None:
|
if task is not None:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
aug_rows = [
|
||||||
|
r
|
||||||
|
for r in persistent
|
||||||
|
if r.get("style") == "task_aug" and r.get("role") == "user"
|
||||||
|
]
|
||||||
|
if aug_rows:
|
||||||
|
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||||
|
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||||
|
# is process-randomized).
|
||||||
|
digest = hashlib.blake2b(
|
||||||
|
f"task_aug:{sample_idx}".encode(), digest_size=8
|
||||||
|
).digest()
|
||||||
|
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||||
|
chosen = aug_rows[idx].get("content")
|
||||||
|
if chosen:
|
||||||
|
return str(chosen)
|
||||||
|
|
||||||
if dataset_ctx is None:
|
if dataset_ctx is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(dataset_ctx, dict):
|
if isinstance(dataset_ctx, dict):
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def test_language_arrow_schema_has_expected_fields():
|
|||||||
|
|
||||||
|
|
||||||
def test_style_registry_routes_columns():
|
def test_style_registry_routes_columns():
|
||||||
assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES
|
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
|
||||||
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
||||||
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
||||||
|
|
||||||
@@ -51,6 +51,7 @@ def test_style_registry_routes_columns():
|
|||||||
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
||||||
|
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
||||||
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
||||||
assert column_for_style("trace") == LANGUAGE_EVENTS
|
assert column_for_style("trace") == LANGUAGE_EVENTS
|
||||||
|
|||||||
@@ -289,6 +289,87 @@ def test_per_camera_blend_renders_both_views():
|
|||||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||||
|
rephrasings = [
|
||||||
|
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
|
||||||
|
]
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# No explicit task override → resolver consults persistent rows.
|
||||||
|
seen: set[str] = set()
|
||||||
|
for sample_idx in range(64):
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=sample_idx,
|
||||||
|
dataset_ctx={"task": "canonical kitchen task"},
|
||||||
|
)
|
||||||
|
seen.add(rendered["messages"][0]["content"])
|
||||||
|
# Every rephrasing should be reachable across enough samples.
|
||||||
|
assert seen == {r["content"] for r in rephrasings}
|
||||||
|
# Same sample_idx → same pick (determinism).
|
||||||
|
a = render_sample(
|
||||||
|
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
b = render_sample(
|
||||||
|
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=PERSISTENT, # no task_aug rows
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=0,
|
||||||
|
dataset_ctx={"task": "clean the kitchen"},
|
||||||
|
)
|
||||||
|
assert rendered["messages"][0]["content"] == "clean the kitchen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_explicit_override_beats_rephrasings():
|
||||||
|
rephrasings = [
|
||||||
|
persistent_row("user", "rephrased one", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "rephrased two", "task_aug", 0.0),
|
||||||
|
]
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=0,
|
||||||
|
task="explicit override wins",
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||||
|
|
||||||
|
|
||||||
def test_canonical_recipe_can_render_low_level_branch():
|
def test_canonical_recipe_can_render_low_level_branch():
|
||||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||||
|
|||||||
Reference in New Issue
Block a user