Files
lerobot/tests/datasets/test_language.py
T
Pepijn c1a0c601e2 feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
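  Python's per-process randomized ``hash``) so that, over an epoch, each
  episode's samples are spread across its rephrasings (episodes with many
  more frames than rephrasings will typically hit all of them), while the
  same sample still resolves identically across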
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 16:45:39 +02:00

153 lines
5.5 KiB
Python

#!/usr/bin/env python
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.io_utils import write_info
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style,
is_view_dependent_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
validate_camera_field,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields():
    """Both language column types are lists of structs with a fixed field order.

    Persistent rows include a ``timestamp`` field; event rows do not.
    """
    persistent_fields = ["role", "content", "style", "timestamp", "camera", "tool_calls"]
    event_fields = ["role", "content", "style", "camera", "tool_calls"]

    # Resolve both element types up front, before asserting on either.
    row_types = (
        language_persistent_arrow_type().value_type,
        language_events_arrow_type().value_type,
    )
    for row_type, expected_fields in zip(row_types, (persistent_fields, event_fields)):
        assert isinstance(row_type, pa.StructType)
        assert row_type.names == expected_fields
def test_style_registry_routes_columns():
    """Pin the persistent / event-only style sets and the column each style routes to.

    Persistent styles map to ``LANGUAGE_PERSISTENT``. Event-only styles and a
    missing (``None``) style map to ``LANGUAGE_EVENTS``.
    """
    assert PERSISTENT_STYLES == {"subtask", "plan", "memory", "motion", "task_aug"}
    assert EVENT_ONLY_STYLES == {"interjection", "vqa", "trace"}
    # Every routed style must also be a registered style.
    assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY

    expected_column = {
        "subtask": LANGUAGE_PERSISTENT,
        "plan": LANGUAGE_PERSISTENT,
        "memory": LANGUAGE_PERSISTENT,
        "motion": LANGUAGE_PERSISTENT,
        "task_aug": LANGUAGE_PERSISTENT,
        "interjection": LANGUAGE_EVENTS,
        "vqa": LANGUAGE_EVENTS,
        "trace": LANGUAGE_EVENTS,
        None: LANGUAGE_EVENTS,
    }
    for style, column in expected_column.items():
        assert column_for_style(style) == column, f"style={style!r}"
def test_view_dependent_styles():
    """Only ``vqa`` and ``trace`` are view-dependent (i.e. carry a camera tag).

    ``motion`` lives in ``PERSISTENT_STYLES`` and is described in robot-frame
    (joint / Cartesian) terms, so it is NOT view-dependent. Every persistent
    style, including ``memory`` and ``task_aug``, is checked explicitly so
    that none of them can silently become camera-tagged.
    """
    assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
    for style in ("vqa", "trace"):
        assert is_view_dependent_style(style), f"style={style!r}"
    # All persistent styles, the interjection event style, and a missing style.
    for style in ("subtask", "plan", "memory", "motion", "task_aug", "interjection", None):
        assert not is_view_dependent_style(style), f"style={style!r}"
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
    """View-dependent styles accept a camera key and reject a missing or empty one."""
    accepted = (
        ("vqa", "observation.images.top"),
        ("trace", "observation.images.front"),
    )
    for style, camera in accepted:
        validate_camera_field(style, camera)

    # Both ``None`` and the empty string count as "no camera".
    rejected = (("vqa", None), ("trace", ""))
    for style, camera in rejected:
        with pytest.raises(ValueError, match="view-dependent"):
            validate_camera_field(style, camera)
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
    """Non-view-dependent styles accept ``camera=None`` and reject any camera key.

    The same set of styles is used for both the accept and the reject checks.
    It covers every persistent style (including ``task_aug``), the
    ``interjection`` event style, and a missing (``None``) style.
    """
    non_view_dependent = ("subtask", "plan", "memory", "motion", "task_aug", "interjection", None)

    for style in non_view_dependent:
        validate_camera_field(style, None)

    for style in non_view_dependent:
        with pytest.raises(ValueError, match="must have camera=None"):
            validate_camera_field(style, "observation.images.top")
def test_unknown_style_rejected():
    """``column_for_style`` raises ``ValueError`` for a style it does not know."""
    unregistered_style = "surprise"
    with pytest.raises(ValueError, match="Unknown language style"):
        column_for_style(unregistered_style)
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
    """Language columns spliced into the data parquet survive a ``LeRobotDataset`` reload.

    Steps:
    1. Record a two-frame episode.
    2. Rewrite its data file so every frame carries one persistent
       ``subtask`` row and only the first frame carries a ``vqa`` event.
    3. Register the language features in ``info``.
    4. Reload and check each frame's language columns come back unchanged.
    """
    root = tmp_path / "language_dataset"
    dataset = empty_lerobot_dataset_factory(
        root=root,
        features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
        use_videos=False,
    )
    for state in ([0.0, 1.0], [1.0, 2.0]):
        dataset.add_frame({"state": np.array(state, dtype=np.float32), "task": "tidy"})
    dataset.save_episode()
    dataset.finalize()

    subtask_row = {
        "role": "assistant",
        "content": "reach for the cup",
        "style": "subtask",
        "timestamp": 0.0,
        "camera": None,
        "tool_calls": None,
    }
    persistent_rows = [subtask_row]
    vqa_event = {
        "role": "user",
        "content": "what is visible?",
        "style": "vqa",
        "camera": "observation.images.top",
        "tool_calls": None,
    }

    # Splice the language columns directly into the episode's data file:
    # persistent rows on both frames, the event only on the first frame.
    data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
    frame_table = pd.read_parquet(data_path)
    frame_table[LANGUAGE_PERSISTENT] = [persistent_rows, persistent_rows]
    frame_table[LANGUAGE_EVENTS] = [[vqa_event], []]
    frame_table.to_parquet(data_path)

    # Declare the new columns in the dataset info so the reload picks them up.
    info = dataset.meta.info
    info["features"].update(language_feature_info())
    write_info(info, root)

    reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
    first = reloaded[0]
    second = reloaded[1]

    assert first[LANGUAGE_PERSISTENT] == persistent_rows
    assert first[LANGUAGE_EVENTS] == [vqa_event]
    assert second[LANGUAGE_PERSISTENT] == persistent_rows
    assert second[LANGUAGE_EVENTS] == []