From 8fa8323c91f76e59a74b763e766903617ee44cd1 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 4 May 2026 15:17:15 +0000 Subject: [PATCH] fix(annotate): sync language metadata after parquet rewrite Ensure annotated datasets advertise language columns in meta/info.json so non-streaming dataset loads cast against the rewritten parquet schema. Co-authored-by: Cursor --- .../steerable_pipeline/executor.py | 41 ++++++++++----- tests/annotations/test_writer.py | 52 +++++++++++++++++++ 2 files changed, 79 insertions(+), 14 deletions(-) diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py index 79a7f1614..71578c2d7 100644 --- a/src/lerobot/annotations/steerable_pipeline/executor.py +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -129,26 +129,24 @@ class Executor: written = self.writer.write_all(records, staging_dir, root) print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True) - # Persist the tool catalog to meta/info.json so chat-template - # consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read - # it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and - # additive: anything the user pre-populated is preserved; we only - # ensure the canonical ``say`` schema is present. - self._ensure_tools_in_info(root) + # Keep meta/info.json aligned with the parquet schema we just wrote. + # Idempotent and additive: existing user metadata is preserved. + self._ensure_annotation_metadata_in_info(root) return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report) - def _ensure_tools_in_info(self, root: Path) -> None: - """Write ``meta/info.json["tools"]`` if missing the canonical ``say``. + @staticmethod + def _ensure_annotation_metadata_in_info(root: Path) -> None: + """Write language features and canonical tools to ``meta/info.json``. - Reads any user-declared tools already in ``info.json`` and merges - the canonical ``SAY_TOOL_SCHEMA`` into the list (deduped by - ``function.name``). Writes back to disk only if the list - changed. + ``LanguageColumnsWriter`` adds ``language_persistent`` and + ``language_events`` to parquet shards. The metadata must advertise + those columns too, otherwise non-streaming ``LeRobotDataset`` loads + cast against the old schema and fail on the extra parquet columns. """ import json # noqa: PLC0415 - from lerobot.datasets.language import SAY_TOOL_SCHEMA # noqa: PLC0415 + from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415 info_path = root / "meta" / "info.json" if not info_path.exists(): @@ -159,6 +157,16 @@ class Executor: print(f"[annotate] could not read {info_path}: {exc}", flush=True) return + changed = False + + features = info.get("features") + if not isinstance(features, dict): + features = {} + merged_features = {**features, **language_feature_info()} + if merged_features != features: + info["features"] = merged_features + changed = True + existing = info.get("tools") if not isinstance(existing, list): existing = [] @@ -172,9 +180,14 @@ class Executor: merged.append(SAY_TOOL_SCHEMA) if merged != existing: info["tools"] = merged + changed = True + + if changed: info_path.write_text(json.dumps(info, indent=2)) print( - f"[annotate] meta/info.json: tools={[t['function']['name'] for t in merged]}", + "[annotate] meta/info.json: " + f"language_features={list(language_feature_info())}, " + f"tools={[t['function']['name'] for t in merged]}", flush=True, ) diff --git a/tests/annotations/test_writer.py b/tests/annotations/test_writer.py index c6ff03178..9a736cc1c 100644 --- a/tests/annotations/test_writer.py +++ b/tests/annotations/test_writer.py @@ -167,6 +167,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No "content": "where is the cup?", "style": "vqa", "timestamp": 0.4, + "camera": "observation.images.front", "tool_calls": None, }, { @@ -177,6 +178,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No ), "style": "vqa", "timestamp": 0.4, + "camera": "observation.images.front", "tool_calls": None, }, ], @@ -285,6 +287,56 @@ def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: assert "tools" not in table.column_names +def test_annotation_metadata_sync_allows_non_streaming_load( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + """Annotated parquet columns must be declared in ``meta/info.json``. + + ``LeRobotDataset`` loads non-streaming datasets by casting parquet + against metadata-derived HF features. If the annotation writer adds + language columns but metadata stays stale, that cast fails with a column + mismatch. + """ + from lerobot.annotations.steerable_pipeline.executor import Executor + from lerobot.datasets.feature_utils import get_hf_features_from_features + from lerobot.datasets.io_utils import load_info, load_nested_dataset + from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info + + info_path = fixture_dataset_root / "meta" / "info.json" + info = json.loads(info_path.read_text()) + info["features"] = { + "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, + "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, + "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, + "task_index": {"dtype": "int64", "shape": (1,), "names": None}, + } + info_path.write_text(json.dumps(info, indent=2)) + + staging_dir = tmp_path / "stage" + _stage_episode( + staging_dir, + 0, + module_1=[ + {"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None} + ], + ) + records = list(iter_episodes(fixture_dataset_root)) + LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) + + Executor._ensure_annotation_metadata_in_info(fixture_dataset_root) + + synced = load_info(fixture_dataset_root) + for key, feature in language_feature_info().items(): + assert synced["features"][key] == feature + + hf_features = get_hf_features_from_features(synced["features"]) + dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features) + + assert LANGUAGE_PERSISTENT in dataset.column_names + assert LANGUAGE_EVENTS in dataset.column_names + assert len(dataset) == 24 + + def test_speech_atom_shape_matches_plan_spec() -> None: atom = speech_atom(2.5, "I'm cleaning up!") assert atom["role"] == "assistant"