diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index cc0f70bf9..b8cc649bf 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -27,12 +27,12 @@ LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls") EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls") -CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"} +CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"} EXTENDED_STYLES = set() STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES -PERSISTENT_STYLES = {"subtask", "plan", "memory"} -EVENT_ONLY_STYLES = {"interjection", "vqa"} +PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"} +EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"} LanguageColumn = Literal["language_persistent", "language_events"] diff --git a/tests/datasets/test_language.py b/tests/datasets/test_language.py index b5a6ce8ee..4f4066249 100644 --- a/tests/datasets/test_language.py +++ b/tests/datasets/test_language.py @@ -33,15 +33,17 @@ def test_language_arrow_schema_has_expected_fields(): def test_style_registry_routes_columns(): - assert {"subtask", "plan", "memory"} == PERSISTENT_STYLES - assert {"interjection", "vqa"} == EVENT_ONLY_STYLES + assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES + assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY assert column_for_style("subtask") == LANGUAGE_PERSISTENT assert column_for_style("plan") == LANGUAGE_PERSISTENT assert column_for_style("memory") == LANGUAGE_PERSISTENT + assert column_for_style("motion") == LANGUAGE_PERSISTENT assert column_for_style("interjection") == LANGUAGE_EVENTS assert column_for_style("vqa") == LANGUAGE_EVENTS + assert column_for_style("trace") == LANGUAGE_EVENTS assert column_for_style(None) == LANGUAGE_EVENTS