fix(annotate): sync language metadata after parquet rewrite

Ensure annotated datasets advertise language columns in meta/info.json so non-streaming dataset loads cast against the rewritten parquet schema.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-04 15:17:15 +00:00
parent 73740ecf4b
commit 8fa8323c91
2 changed files with 79 additions and 14 deletions
@@ -129,26 +129,24 @@ class Executor:
written = self.writer.write_all(records, staging_dir, root) written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True) print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Persist the tool catalog to meta/info.json so chat-template # Keep meta/info.json aligned with the parquet schema we just wrote.
# consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read # Idempotent and additive: existing user metadata is preserved.
# it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and self._ensure_annotation_metadata_in_info(root)
# additive: anything the user pre-populated is preserved; we only
# ensure the canonical ``say`` schema is present.
self._ensure_tools_in_info(root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report) return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
def _ensure_tools_in_info(self, root: Path) -> None: @staticmethod
"""Write ``meta/info.json["tools"]`` if missing the canonical ``say``. def _ensure_annotation_metadata_in_info(root: Path) -> None:
"""Write language features and canonical tools to ``meta/info.json``.
Reads any user-declared tools already in ``info.json`` and merges ``LanguageColumnsWriter`` adds ``language_persistent`` and
the canonical ``SAY_TOOL_SCHEMA`` into the list (deduped by ``language_events`` to parquet shards. The metadata must advertise
``function.name``). Writes back to disk only if the list those columns too, otherwise non-streaming ``LeRobotDataset`` loads
changed. cast against the old schema and fail on the extra parquet columns.
""" """
import json # noqa: PLC0415 import json # noqa: PLC0415
from lerobot.datasets.language import SAY_TOOL_SCHEMA # noqa: PLC0415 from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
info_path = root / "meta" / "info.json" info_path = root / "meta" / "info.json"
if not info_path.exists(): if not info_path.exists():
@@ -159,6 +157,16 @@ class Executor:
print(f"[annotate] could not read {info_path}: {exc}", flush=True) print(f"[annotate] could not read {info_path}: {exc}", flush=True)
return return
changed = False
features = info.get("features")
if not isinstance(features, dict):
features = {}
merged_features = {**features, **language_feature_info()}
if merged_features != features:
info["features"] = merged_features
changed = True
existing = info.get("tools") existing = info.get("tools")
if not isinstance(existing, list): if not isinstance(existing, list):
existing = [] existing = []
@@ -172,9 +180,14 @@ class Executor:
merged.append(SAY_TOOL_SCHEMA) merged.append(SAY_TOOL_SCHEMA)
if merged != existing: if merged != existing:
info["tools"] = merged info["tools"] = merged
changed = True
if changed:
info_path.write_text(json.dumps(info, indent=2)) info_path.write_text(json.dumps(info, indent=2))
print( print(
f"[annotate] meta/info.json: tools={[t['function']['name'] for t in merged]}", "[annotate] meta/info.json: "
f"language_features={list(language_feature_info())}, "
f"tools={[t['function']['name'] for t in merged]}",
flush=True, flush=True,
) )
+52
View File
@@ -167,6 +167,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
"content": "where is the cup?", "content": "where is the cup?",
"style": "vqa", "style": "vqa",
"timestamp": 0.4, "timestamp": 0.4,
"camera": "observation.images.front",
"tool_calls": None, "tool_calls": None,
}, },
{ {
@@ -177,6 +178,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
), ),
"style": "vqa", "style": "vqa",
"timestamp": 0.4, "timestamp": 0.4,
"camera": "observation.images.front",
"tool_calls": None, "tool_calls": None,
}, },
], ],
@@ -285,6 +287,56 @@ def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path:
assert "tools" not in table.column_names assert "tools" not in table.column_names
def test_annotation_metadata_sync_allows_non_streaming_load(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """Syncing metadata after annotation keeps non-streaming loads working.

    A non-streaming ``LeRobotDataset`` load casts the parquet shards against
    HF features derived from ``meta/info.json``. Once the annotation writer
    has added the language columns, stale metadata makes that cast fail with
    a column mismatch, so the columns must be declared in ``info.json`` too.
    """
    from lerobot.annotations.steerable_pipeline.executor import Executor
    from lerobot.datasets.feature_utils import get_hf_features_from_features
    from lerobot.datasets.io_utils import load_info, load_nested_dataset
    from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info

    # Start from metadata that only declares the scalar bookkeeping columns,
    # i.e. nothing about the language columns the writer is about to add.
    scalar_dtypes = (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("task_index", "int64"),
    )
    meta_file = fixture_dataset_root / "meta" / "info.json"
    metadata = json.loads(meta_file.read_text())
    metadata["features"] = {
        column: {"dtype": dtype, "shape": (1,), "names": None}
        for column, dtype in scalar_dtypes
    }
    meta_file.write_text(json.dumps(metadata, indent=2))

    # Stage a single subtask row for episode 0 and rewrite the parquet shards.
    stage_root = tmp_path / "stage"
    subtask_row = {
        "role": "assistant",
        "content": "do X",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(stage_root, 0, module_1=[subtask_row])
    episode_records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episode_records, stage_root, fixture_dataset_root)

    # The step under test: bring meta/info.json in line with the new shards.
    Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)

    # Every language feature must now be declared exactly as advertised.
    refreshed = load_info(fixture_dataset_root)
    for name, expected in language_feature_info().items():
        assert refreshed["features"][name] == expected

    # The metadata-derived features must load the rewritten parquet data.
    loaded = load_nested_dataset(
        fixture_dataset_root / "data",
        features=get_hf_features_from_features(refreshed["features"]),
    )
    assert LANGUAGE_PERSISTENT in loaded.column_names
    assert LANGUAGE_EVENTS in loaded.column_names
    assert len(loaded) == 24
def test_speech_atom_shape_matches_plan_spec() -> None: def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!") atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant" assert atom["role"] == "assistant"