diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index cc52835cb..bce1b33aa 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -27,11 +27,20 @@ LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls") EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls") -CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"} +CORE_STYLES = { + "subtask", + "plan", + "memory", + "motion", + "interjection", + "vqa", + "trace", + "task_aug", +} EXTENDED_STYLES = set() STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES -PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"} +PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"} EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"} # Styles whose ``content`` is grounded in a specific camera view. Rows of these diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 42cd03a9a..1f4ed2749 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -198,6 +198,7 @@ def render_sample( persistent=persistent_rows, events=event_rows, t=t, + sample_idx=sample_idx, task=task, dataset_ctx=dataset_ctx, ) @@ -232,21 +233,65 @@ def _resolve_bindings( persistent: Sequence[LanguageRow], events: Sequence[LanguageRow], t: float, + sample_idx: int, task: str | None, dataset_ctx: Any | None, ) -> dict[str, LanguageRow | str | None]: """Resolve every binding in ``recipe`` (plus ``task``) at time ``t``.""" - bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)} + bindings: dict[str, LanguageRow | str | None] = { + "task": _resolve_task( + task, dataset_ctx, persistent=persistent, sample_idx=sample_idx + ), + } specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})} for name, spec in specs.items(): bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t) return bindings -def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None: - """Return ``task`` if set, otherwise look it up on ``dataset_ctx``.""" +def _resolve_task( + task: str | None, + dataset_ctx: Any | None, + *, + persistent: Sequence[LanguageRow] = (), + sample_idx: int = 0, +) -> str | None: + """Return the task string for ``sample_idx``. + + Resolution order: + + 1. Explicit ``task`` override (caller-supplied) wins. + 2. If ``persistent`` contains rows of style ``task_aug`` (role=user), + deterministically pick one by ``sample_idx`` so each frame of an + episode rotates through the available rephrasings across an epoch. + This realizes Xiao 2022 / CAST-style task-prompt diversity without + changing ``meta/tasks.parquet`` and without forcing recipes to opt + in: ``${task}`` automatically picks a rephrasing when one exists, + and falls back to the canonical task otherwise. Recipes that want + the literal canonical task can override the binding. + 3. Otherwise read the canonical task from ``dataset_ctx`` (which is + backed by ``meta/tasks.parquet``). + """ if task is not None: return task + + aug_rows = [ + r + for r in persistent + if r.get("style") == "task_aug" and r.get("role") == "user" + ] + if aug_rows: + # Deterministic, blake2b-based pick keyed on sample_idx so the + # rotation is reproducible across runs (Python's built-in ``hash`` + # is process-randomized). + digest = hashlib.blake2b( + f"task_aug:{sample_idx}".encode(), digest_size=8 + ).digest() + idx = int.from_bytes(digest, "big") % len(aug_rows) + chosen = aug_rows[idx].get("content") + if chosen: + return str(chosen) + if dataset_ctx is None: return None if isinstance(dataset_ctx, dict): diff --git a/tests/datasets/test_language.py b/tests/datasets/test_language.py index 5c354a5e4..5d1cac80f 100644 --- a/tests/datasets/test_language.py +++ b/tests/datasets/test_language.py @@ -43,7 +43,7 @@ def test_language_arrow_schema_has_expected_fields(): def test_style_registry_routes_columns(): - assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES + assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY @@ -51,6 +51,7 @@ def test_style_registry_routes_columns(): assert column_for_style("plan") == LANGUAGE_PERSISTENT assert column_for_style("memory") == LANGUAGE_PERSISTENT assert column_for_style("motion") == LANGUAGE_PERSISTENT + assert column_for_style("task_aug") == LANGUAGE_PERSISTENT assert column_for_style("interjection") == LANGUAGE_EVENTS assert column_for_style("vqa") == LANGUAGE_EVENTS assert column_for_style("trace") == LANGUAGE_EVENTS diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index ddf8ca263..a7bc026ca 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -289,6 +289,87 @@ def test_per_camera_blend_renders_both_views(): assert rendered_wrist["messages"][1]["content"] == '{"count": 1}' +def test_resolve_task_picks_rephrasing_deterministically_per_sample(): + rephrasings = [ + persistent_row("user", "tidy the kitchen", "task_aug", 0.0), + persistent_row("user", "please clean up the kitchen", "task_aug", 0.0), + persistent_row("user", "kitchen needs tidying", "task_aug", 0.0), + persistent_row("user", "make the kitchen clean", "task_aug", 0.0), + ] + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + ] + ) + + # No explicit task override → resolver consults persistent rows. + seen: set[str] = set() + for sample_idx in range(64): + rendered = render_sample( + recipe=recipe, + persistent=rephrasings, + events=[], + t=0.0, + sample_idx=sample_idx, + dataset_ctx={"task": "canonical kitchen task"}, + ) + seen.add(rendered["messages"][0]["content"]) + # Every rephrasing should be reachable across enough samples. + assert seen == {r["content"] for r in rephrasings} + # Same sample_idx → same pick (determinism). + a = render_sample( + recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42, + dataset_ctx={"task": "canonical"}, + ) + b = render_sample( + recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42, + dataset_ctx={"task": "canonical"}, + ) + assert a["messages"][0]["content"] == b["messages"][0]["content"] + + +def test_resolve_task_falls_back_to_canonical_without_rephrasings(): + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + ] + ) + rendered = render_sample( + recipe=recipe, + persistent=PERSISTENT, # no task_aug rows + events=[], + t=0.0, + sample_idx=0, + dataset_ctx={"task": "clean the kitchen"}, + ) + assert rendered["messages"][0]["content"] == "clean the kitchen" + + +def test_resolve_task_explicit_override_beats_rephrasings(): + rephrasings = [ + persistent_row("user", "rephrased one", "task_aug", 0.0), + persistent_row("user", "rephrased two", "task_aug", 0.0), + ] + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + ] + ) + rendered = render_sample( + recipe=recipe, + persistent=rephrasings, + events=[], + t=0.0, + sample_idx=0, + task="explicit override wins", + dataset_ctx={"task": "canonical"}, + ) + assert rendered["messages"][0]["content"] == "explicit override wins" + + def test_canonical_recipe_can_render_low_level_branch(): recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})