mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.
- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
``language_persistent`` so PR 2 writers route it correctly.
- ``language_render.py``: ``_resolve_task`` now consults the persistent
slice for rows of ``style="task_aug", role="user"``. When any exist
it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
Python's randomized hash) so an epoch sees every rephrasing of every
episode while the same sample still resolves identically across
reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
no rephrasings are present, so existing datasets and unannotated runs
keep their behaviour. Explicit ``task=`` overrides still win.
- Tests: rephrasing coverage across samples, determinism on repeat
``sample_idx``, fallback when persistent has no ``task_aug`` rows,
and explicit override priority.
Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -289,6 +289,87 @@ def test_per_camera_blend_renders_both_views():
|
||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||
|
||||
|
||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
rephrasings = [
|
||||
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
|
||||
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
|
||||
# No explicit task override → resolver consults persistent rows.
|
||||
seen: set[str] = set()
|
||||
for sample_idx in range(64):
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=sample_idx,
|
||||
dataset_ctx={"task": "canonical kitchen task"},
|
||||
)
|
||||
seen.add(rendered["messages"][0]["content"])
|
||||
# Every rephrasing should be reachable across enough samples.
|
||||
assert seen == {r["content"] for r in rephrasings}
|
||||
# Same sample_idx → same pick (determinism).
|
||||
a = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
b = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||
|
||||
|
||||
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT, # no task_aug rows
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "clean the kitchen"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "clean the kitchen"
|
||||
|
||||
|
||||
def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
rephrasings = [
|
||||
persistent_row("user", "rephrased one", "task_aug", 0.0),
|
||||
persistent_row("user", "rephrased two", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="explicit override wins",
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
|
||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||
|
||||
Reference in New Issue
Block a user