Files
lerobot/tests/datasets/test_language.py
T
Pepijn 0b06790da0 feat(language): add motion (persistent) and trace (event-only) styles
Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:21:49 +02:00

100 lines
3.3 KiB
Python

#!/usr/bin/env python
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.io_utils import write_info
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields():
    """Check the struct layout behind both language Arrow types.

    The value type of each language column must be a struct whose field
    names appear in the expected order; only the persistent rows carry a
    ``timestamp`` field.
    """
    persistent_struct = language_persistent_arrow_type().value_type
    events_struct = language_events_arrow_type().value_type

    expectations = (
        (persistent_struct, ["role", "content", "style", "timestamp", "tool_calls"]),
        (events_struct, ["role", "content", "style", "tool_calls"]),
    )
    for struct_type, field_names in expectations:
        assert isinstance(struct_type, pa.StructType)
        assert struct_type.names == field_names
def test_style_registry_routes_columns():
    """Check style membership sets and the column each style is routed to.

    Persistent styles (including ``motion``) must map to the persistent
    column, event-only styles (including ``trace``) to the events column,
    and a ``None`` style must also land in the events column.
    """
    expected_persistent = {"subtask", "plan", "memory", "motion"}
    expected_event_only = {"interjection", "vqa", "trace"}

    assert expected_persistent == PERSISTENT_STYLES
    assert expected_event_only == EVENT_ONLY_STYLES
    # Every routed style must also be a registered style.
    assert (PERSISTENT_STYLES | EVENT_ONLY_STYLES) <= STYLE_REGISTRY

    expected_routing = (
        ("subtask", LANGUAGE_PERSISTENT),
        ("plan", LANGUAGE_PERSISTENT),
        ("memory", LANGUAGE_PERSISTENT),
        ("motion", LANGUAGE_PERSISTENT),
        ("interjection", LANGUAGE_EVENTS),
        ("vqa", LANGUAGE_EVENTS),
        ("trace", LANGUAGE_EVENTS),
        (None, LANGUAGE_EVENTS),
    )
    for style, column in expected_routing:
        assert column_for_style(style) == column
def test_unknown_style_rejected():
    """An unregistered style name must make ``column_for_style`` raise ValueError."""
    unregistered_style = "surprise"
    expect_rejection = pytest.raises(ValueError, match="Unknown language style")
    with expect_rejection:
        column_for_style(unregistered_style)
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
    """Language columns added to the data parquet must survive a dataset reload.

    The test records a two-frame episode, then edits the saved parquet file
    by hand to add the persistent and events language columns, registers the
    language features in the dataset info, and finally reloads the dataset
    to check that both frames return the rows that were written.
    """
    # Language rows that will be injected into the parquet file.
    subtask_row = {
        "role": "assistant",
        "content": "reach for the cup",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    persistent = [subtask_row]
    event = {
        "role": "user",
        "content": "what is visible?",
        "style": "vqa",
        "tool_calls": None,
    }

    # Record a minimal two-frame episode without videos.
    root = tmp_path / "language_dataset"
    state_feature = {"dtype": "float32", "shape": (2,), "names": None}
    dataset = empty_lerobot_dataset_factory(
        root=root,
        features={"state": state_feature},
        use_videos=False,
    )
    for state_values in ([0.0, 1.0], [1.0, 2.0]):
        dataset.add_frame({"state": np.array(state_values, dtype=np.float32), "task": "tidy"})
    dataset.save_episode()
    dataset.finalize()

    # Rewrite the data file in place with the two language columns added:
    # both frames share the persistent rows, only the first frame has an event.
    parquet_file = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
    frames = pd.read_parquet(parquet_file)
    frames[LANGUAGE_PERSISTENT] = [persistent, persistent]
    frames[LANGUAGE_EVENTS] = [[event], []]
    frames.to_parquet(parquet_file)

    # Declare the language features in the dataset info before reloading.
    info = dataset.meta.info
    info["features"].update(language_feature_info())
    write_info(info, root)

    reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
    first, second = reloaded[0], reloaded[1]

    assert first[LANGUAGE_PERSISTENT] == persistent
    assert first[LANGUAGE_EVENTS] == [event]
    assert second[LANGUAGE_PERSISTENT] == persistent
    assert second[LANGUAGE_EVENTS] == []