mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix(annotate): sync language metadata after parquet rewrite
Ensure annotated datasets advertise the language columns in meta/info.json so that non-streaming dataset loads cast against the rewritten parquet schema.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -129,26 +129,24 @@ class Executor:
|
|||||||
written = self.writer.write_all(records, staging_dir, root)
|
written = self.writer.write_all(records, staging_dir, root)
|
||||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||||
|
|
||||||
# Persist the tool catalog to meta/info.json so chat-template
|
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||||
# consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read
|
# Idempotent and additive: existing user metadata is preserved.
|
||||||
# it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and
|
self._ensure_annotation_metadata_in_info(root)
|
||||||
# additive: anything the user pre-populated is preserved; we only
|
|
||||||
# ensure the canonical ``say`` schema is present.
|
|
||||||
self._ensure_tools_in_info(root)
|
|
||||||
|
|
||||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||||
|
|
||||||
def _ensure_tools_in_info(self, root: Path) -> None:
|
@staticmethod
|
||||||
"""Write ``meta/info.json["tools"]`` if missing the canonical ``say``.
|
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||||
|
"""Write language features and canonical tools to ``meta/info.json``.
|
||||||
|
|
||||||
Reads any user-declared tools already in ``info.json`` and merges
|
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||||
the canonical ``SAY_TOOL_SCHEMA`` into the list (deduped by
|
``language_events`` to parquet shards. The metadata must advertise
|
||||||
``function.name``). Writes back to disk only if the list
|
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||||
changed.
|
cast against the old schema and fail on the extra parquet columns.
|
||||||
"""
|
"""
|
||||||
import json # noqa: PLC0415
|
import json # noqa: PLC0415
|
||||||
|
|
||||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA # noqa: PLC0415
|
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||||
|
|
||||||
info_path = root / "meta" / "info.json"
|
info_path = root / "meta" / "info.json"
|
||||||
if not info_path.exists():
|
if not info_path.exists():
|
||||||
@@ -159,6 +157,16 @@ class Executor:
|
|||||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
features = info.get("features")
|
||||||
|
if not isinstance(features, dict):
|
||||||
|
features = {}
|
||||||
|
merged_features = {**features, **language_feature_info()}
|
||||||
|
if merged_features != features:
|
||||||
|
info["features"] = merged_features
|
||||||
|
changed = True
|
||||||
|
|
||||||
existing = info.get("tools")
|
existing = info.get("tools")
|
||||||
if not isinstance(existing, list):
|
if not isinstance(existing, list):
|
||||||
existing = []
|
existing = []
|
||||||
@@ -172,9 +180,14 @@ class Executor:
|
|||||||
merged.append(SAY_TOOL_SCHEMA)
|
merged.append(SAY_TOOL_SCHEMA)
|
||||||
if merged != existing:
|
if merged != existing:
|
||||||
info["tools"] = merged
|
info["tools"] = merged
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if changed:
|
||||||
info_path.write_text(json.dumps(info, indent=2))
|
info_path.write_text(json.dumps(info, indent=2))
|
||||||
print(
|
print(
|
||||||
f"[annotate] meta/info.json: tools={[t['function']['name'] for t in merged]}",
|
"[annotate] meta/info.json: "
|
||||||
|
f"language_features={list(language_feature_info())}, "
|
||||||
|
f"tools={[t['function']['name'] for t in merged]}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
|
|||||||
"content": "where is the cup?",
|
"content": "where is the cup?",
|
||||||
"style": "vqa",
|
"style": "vqa",
|
||||||
"timestamp": 0.4,
|
"timestamp": 0.4,
|
||||||
|
"camera": "observation.images.front",
|
||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -177,6 +178,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
|
|||||||
),
|
),
|
||||||
"style": "vqa",
|
"style": "vqa",
|
||||||
"timestamp": 0.4,
|
"timestamp": 0.4,
|
||||||
|
"camera": "observation.images.front",
|
||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -285,6 +287,56 @@ def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path:
|
|||||||
assert "tools" not in table.column_names
|
assert "tools" not in table.column_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||||
|
fixture_dataset_root: Path, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""Annotated parquet columns must be declared in ``meta/info.json``.
|
||||||
|
|
||||||
|
``LeRobotDataset`` loads non-streaming datasets by casting parquet
|
||||||
|
against metadata-derived HF features. If the annotation writer adds
|
||||||
|
language columns but metadata stays stale, that cast fails with a column
|
||||||
|
mismatch.
|
||||||
|
"""
|
||||||
|
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||||
|
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||||
|
from lerobot.datasets.io_utils import load_info, load_nested_dataset
|
||||||
|
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info
|
||||||
|
|
||||||
|
info_path = fixture_dataset_root / "meta" / "info.json"
|
||||||
|
info = json.loads(info_path.read_text())
|
||||||
|
info["features"] = {
|
||||||
|
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
|
}
|
||||||
|
info_path.write_text(json.dumps(info, indent=2))
|
||||||
|
|
||||||
|
staging_dir = tmp_path / "stage"
|
||||||
|
_stage_episode(
|
||||||
|
staging_dir,
|
||||||
|
0,
|
||||||
|
module_1=[
|
||||||
|
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
records = list(iter_episodes(fixture_dataset_root))
|
||||||
|
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||||
|
|
||||||
|
Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)
|
||||||
|
|
||||||
|
synced = load_info(fixture_dataset_root)
|
||||||
|
for key, feature in language_feature_info().items():
|
||||||
|
assert synced["features"][key] == feature
|
||||||
|
|
||||||
|
hf_features = get_hf_features_from_features(synced["features"])
|
||||||
|
dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)
|
||||||
|
|
||||||
|
assert LANGUAGE_PERSISTENT in dataset.column_names
|
||||||
|
assert LANGUAGE_EVENTS in dataset.column_names
|
||||||
|
assert len(dataset) == 24
|
||||||
|
|
||||||
|
|
||||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||||
assert atom["role"] == "assistant"
|
assert atom["role"] == "assistant"
|
||||||
|
|||||||
Reference in New Issue
Block a user