mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
feat: language annotation pipeline (PR 2/3)
Adds the steerable annotation pipeline (`lerobot-annotate`) that populates the `language_persistent` and `language_events` columns introduced in PR 1 directly into `data/chunk-*/file-*.parquet`. No flavor namespace, no sidecar tree. Modules produced: - Module 1 (plan_subtasks_memory): Pi0.7-style subtasks, plan (init + refresh on interjection), MEM-style memory at subtask boundaries. - Module 2 (interjections_and_speech): t=0 speech-only acknowledgement, mid-episode paired interjection + speech tool-call atom. - Module 3 (general_vqa): bbox/keypoint/count/attribute/spatial pairs at configurable cadence with one-retry JSON validation. Writer enforces: per-episode persistent identity, exact-frame event timestamps, column routing per `column_for_style`, dataset-level `tools` column with the `say` schema, drops legacy `subtask_index`. Validator runs against staged JSONL artifacts before the writer rewrites parquet. Adds `lerobot-annotate` console script, `annotations` extra (datatrove + optional vllm), `make annotation-e2e` opt-in smoke target, and `docs/source/annotation_pipeline.mdx`. Branched from PR 1 (`feat/language-columns`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helpers shared across annotation-pipeline tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
|
||||
def make_canned_responder(
    responses_by_marker: dict[str, Any],
    default: Any = None,
) -> StubVlmClient:
    """Build a stub VLM that answers based on marker substrings in the prompt.

    The stub scans the conversation for the most recent user-message text
    (for list-shaped content, the last ``{"type": "text"}`` block wins) and
    replies with the response registered for the first marker contained in
    that text. Falls back to ``default`` when no marker matches.
    """

    def _latest_user_text(messages: list[dict[str, Any]]) -> str:
        # Walk every message; later user messages overwrite earlier ones.
        text = ""
        for msg in messages:
            if msg.get("role") != "user":
                continue
            body = msg.get("content")
            if isinstance(body, str):
                text = body
            elif isinstance(body, list):
                for part in body:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text = part.get("text", "")
        return text

    def responder(messages: list[dict[str, Any]]) -> Any:
        prompt = _latest_user_text(messages)
        # First marker (in dict insertion order) found in the prompt wins.
        return next(
            (reply for marker, reply in responses_by_marker.items() if marker in prompt),
            default,
        )

    return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def encode_vqa_answer(payload: dict[str, Any]) -> str:
    """Serialize a VQA answer payload to deterministic, key-sorted JSON.

    Sorting keys makes the encoding stable across runs so tests can compare
    staged assistant ``content`` strings byte-for-byte.
    """
    return json.dumps(payload, sort_keys=True)
|
||||
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests
|
||||
can exercise real parquet reads and writes without needing a checked-in
|
||||
LFS dataset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_episode_table(
    episode_index: int,
    num_frames: int,
    *,
    fps: int = 10,
    task_index: int = 0,
) -> pa.Table:
    """Build one episode's frame table.

    Includes the legacy ``subtask_index`` column so tests can verify the
    writer drops it. Timestamps are ``frame / fps`` rounded to 6 decimals.
    """
    columns = {
        "episode_index": [episode_index] * num_frames,
        "frame_index": list(range(num_frames)),
        "timestamp": [round(frame / fps, 6) for frame in range(num_frames)],
        "task_index": [task_index] * num_frames,
        # legacy column the writer must drop
        "subtask_index": [0] * num_frames,
    }
    return pa.Table.from_pydict(columns)
|
||||
|
||||
|
||||
def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path:
    """Create a fixture dataset under ``root``.

    ``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
    Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet``
    so the writer's per-shard rewrite path is exercised.

    Returns ``root`` for fluent use from fixtures.
    """
    data_dir = root / "data" / "chunk-000"
    data_dir.mkdir(parents=True, exist_ok=True)

    # Deduplicate task strings with an O(1) dict lookup instead of scanning
    # tasks.values()/items() for every episode. setdefault assigns indices in
    # order of first appearance, matching the original mapping exactly.
    task_index_by_text: dict[str, int] = {}
    for episode_index, num_frames, task_text in episode_specs:
        task_index = task_index_by_text.setdefault(task_text, len(task_index_by_text))
        table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index)
        pq.write_table(table, data_dir / f"file-{episode_index:03d}.parquet")

    meta_dir = root / "meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    # Insertion order gives task_index 0, 1, ... aligned with the task texts.
    tasks_table = pa.Table.from_pydict(
        {
            "task_index": list(task_index_by_text.values()),
            "task": list(task_index_by_text.keys()),
        }
    )
    pq.write_table(tasks_table, meta_dir / "tasks.parquet")

    info = {
        "codebase_version": "v3.1",
        "fps": fps,
        "total_episodes": len(episode_specs),
    }
    (meta_dir / "info.json").write_text(json.dumps(info, indent=2))

    return root
|
||||
|
||||
|
||||
@pytest.fixture
def fixture_dataset_root(tmp_path: Path) -> Path:
    """A tiny dataset with two episodes, 12 frames each at 10 fps."""
    specs = [
        (0, 12, "Could you tidy the kitchen please?"),
        (1, 12, "Please clean up the kitchen"),
    ]
    return _build_dataset(tmp_path / "ds", episode_specs=specs, fps=10)
|
||||
|
||||
|
||||
@pytest.fixture
def single_episode_root(tmp_path: Path) -> Path:
    """A single 30-frame episode at 10 fps, for per-frame cadence tests."""
    spec = (0, 30, "Pour water from the bottle into the cup.")
    return _build_dataset(tmp_path / "ds_one", episode_specs=[spec], fps=10)
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the same fixture used by the pytest suite, runs the full
|
||||
annotation pipeline against it with a stub VLM, and prints a short report.
|
||||
This is intentionally not a pytest test — it exercises the CLI plumbing
|
||||
without depending on conftest.py fixtures.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
|
||||
|
||||
def _build_dataset(root: Path) -> Path:
    """Write a minimal single-episode dataset (30 frames at 10 fps) under ``root``."""
    num_frames = 30
    data_dir = root / "data" / "chunk-000"
    data_dir.mkdir(parents=True, exist_ok=True)

    frame_table = pa.Table.from_pydict(
        {
            "episode_index": [0] * num_frames,
            "frame_index": list(range(num_frames)),
            "timestamp": [round(i / 10, 6) for i in range(num_frames)],
            "task_index": [0] * num_frames,
            # legacy column; the writer is expected to drop it
            "subtask_index": [0] * num_frames,
        }
    )
    pq.write_table(frame_table, data_dir / "file-000.parquet")

    meta_dir = root / "meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    tasks_table = pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]})
    pq.write_table(tasks_table, meta_dir / "tasks.parquet")
    (meta_dir / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10}))
    return root
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
text = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
if "Decompose the demonstration" in text:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
}
|
||||
if "concise hierarchical PLAN" in text:
|
||||
return {"plan": "1. grasp\n2. pour\n3. place"}
|
||||
if "Update the memory" in text:
|
||||
return {"memory": "poured once"}
|
||||
if "acknowledgement the robot" in text:
|
||||
return {"text": "Sure."}
|
||||
if "ONE realistic interruption" in text:
|
||||
return {"interjection": "use less water", "speech": "Using less water."}
|
||||
if "frame-grounded visual question" in text:
|
||||
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
    """Run the full annotation pipeline on a throwaway dataset and print a report.

    Returns 0 on completion so the exit status reflects success.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        dataset_root = _build_dataset(Path(tmp_dir) / "ds")
        stub_vlm = StubVlmClient(responder=_stub_responder)
        config = AnnotationPipelineConfig()
        executor = Executor(
            config=config,
            module_1=PlanSubtasksMemoryModule(vlm=stub_vlm, config=config.module_1),
            module_2=InterjectionsAndSpeechModule(vlm=stub_vlm, config=config.module_2, seed=config.seed),
            module_3=GeneralVqaModule(vlm=stub_vlm, config=config.module_3, seed=config.seed),
            writer=LanguageColumnsWriter(),
            validator=StagingValidator(),
        )
        summary = executor.run(dataset_root)
        phase_report = [(phase.name, phase.episodes_processed) for phase in summary.phases]
        print(f"phases={phase_report}")
        print(f"validation: {summary.validation_report.summary()}")
        print(f"shards rewritten: {len(summary.written_paths)}")
    return 0
|
||||
|
||||
|
||||
# Standalone script entry point (invoked by `make annotation-e2e`), not pytest.
if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1/2/3 unit tests with stubbed VLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Module 1 emits subtask/plan/memory rows, all on exact frame timestamps."""
    canned = {
        "Decompose the demonstration": {
            "subtasks": [
                {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
                {"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
                {"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
            ]
        },
        "write a concise hierarchical PLAN": {"plan": "1. grasp\n2. wipe\n3. place"},
        "Update the memory": {"memory": "wiped the counter once"},
    }
    module = PlanSubtasksMemoryModule(vlm=make_canned_responder(canned), config=Module1Config())
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    rows = staging.read("module_1")
    assert {"subtask", "plan", "memory"} <= {row["style"] for row in rows}

    # Every emitted timestamp must land exactly on a source frame timestamp.
    valid_timestamps = set(record.frame_timestamps)
    for row in rows:
        assert row["timestamp"] in valid_timestamps

    # Exactly one plan row, anchored at the first frame (t0).
    plan_rows = [row for row in rows if row["style"] == "plan"]
    assert len(plan_rows) == 1
    assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
|
||||
|
||||
|
||||
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """With interjections disabled, only the t=0 speech acknowledgement is staged."""
    module = InterjectionsAndSpeechModule(
        vlm=make_canned_responder({"acknowledgement the robot": {"text": "Sure, on it."}}),
        config=Module2Config(max_interjections_per_episode=0),
    )
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    rows = staging.read("module_2")
    assert len(rows) == 1
    (speech,) = rows
    # Speech atoms are tool-call-only: assistant role, no style, no content.
    assert speech["role"] == "assistant"
    assert speech["style"] is None
    assert speech["content"] is None
    assert speech["timestamp"] == record.frame_timestamps[0]
    assert speech["tool_calls"][0]["function"]["name"] == "say"
|
||||
|
||||
|
||||
def test_module2_mid_episode_emits_paired_interjection_and_speech(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """A mid-episode interjection always arrives with a co-timestamped speech atom."""
    responses = {
        "acknowledgement the robot": {"text": "OK."},
        "ONE realistic interruption": {
            "interjection": "actually skip the dishes",
            "speech": "Skipping the dishes.",
        },
    }
    module = InterjectionsAndSpeechModule(
        vlm=make_canned_responder(responses),
        config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2),
        seed=7,
    )
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    rows = staging.read("module_2")
    interjections = [row for row in rows if row["style"] == "interjection"]
    speeches = [row for row in rows if row["style"] is None and row["role"] == "assistant"]
    assert len(interjections) == 1
    # Initial t=0 acknowledgement plus at least the speech paired with the interjection.
    assert len(speeches) >= 2
    paired_t = interjections[0]["timestamp"]
    assert any(abs(speech["timestamp"] - paired_t) < 1e-9 for speech in speeches)
|
||||
|
||||
|
||||
def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None:
    """Module 3 emits at most one VQA pair per frame, on exact frame timestamps."""
    canned_answer = {
        "question": "How many cups?",
        "answer": {"label": "cup", "count": 2, "note": "white & blue"},
    }
    module = GeneralVqaModule(
        vlm=make_canned_responder({"frame-grounded visual question": canned_answer}),
        config=Module3Config(vqa_emission_hz=1.0, K=3),
        seed=1,
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    rows = staging.read("module_3")
    user_ts = [row["timestamp"] for row in rows if row["role"] == "user" and row["style"] == "vqa"]
    assistant_ts = [row["timestamp"] for row in rows if row["role"] == "assistant" and row["style"] == "vqa"]
    # No frame may carry more than one question, nor more than one answer.
    assert len(set(user_ts)) == len(user_ts)
    assert len(set(assistant_ts)) == len(assistant_ts)
    # Every emitted timestamp must coincide exactly with a source frame.
    valid_timestamps = set(record.frame_timestamps)
    for ts in user_ts + assistant_ts:
        assert ts in valid_timestamps
|
||||
|
||||
|
||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
    """Every assistant VQA row stages JSON-decodable ``content``."""
    canned_answer = {
        "question": "Where is the cup?",
        "answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
    }
    module = GeneralVqaModule(
        vlm=make_canned_responder({"frame-grounded visual question": canned_answer}),
        config=Module3Config(vqa_emission_hz=1.0, K=2),
        seed=2,
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    assistant_rows = [
        row for row in staging.read("module_3") if row["role"] == "assistant" and row["style"] == "vqa"
    ]
    for row in assistant_rows:
        # json.loads raising here would fail the test, proving validity.
        decoded = json.loads(row["content"])
        assert "detections" in decoded
|
||||
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
AnnotationPipelineConfig,
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
# Canonical PR 1 recipe resolved relative to this test file; parents[2] is
# assumed to be the repository root (tests/<pkg>/<file>) — TODO confirm layout.
_RECIPE_PATH = (
    Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes" / "pi05_hirobot.yaml"
)
|
||||
|
||||
|
||||
def _build_executor() -> Executor:
    """Wire all three modules to one canned VLM and return a ready Executor."""
    canned = {
        "Decompose the demonstration": {
            "subtasks": [
                {"text": "grasp the bottle", "start": 0.0, "end": 0.5},
                {"text": "pour into the cup", "start": 0.5, "end": 1.0},
                {"text": "place the bottle down", "start": 1.0, "end": 1.5},
            ]
        },
        "write a concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. place"},
        "Update the memory": {"memory": "poured once"},
        "acknowledgement the robot": {"text": "Sure."},
        "ONE realistic interruption": {
            "interjection": "use less water",
            "speech": "Using less water.",
        },
        "frame-grounded visual question": {
            "question": "How many cups?",
            "answer": {"label": "cup", "count": 1},
        },
    }
    vlm = make_canned_responder(canned)
    config = AnnotationPipelineConfig(
        module_1=Module1Config(),
        module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5),
        module_3=Module3Config(vqa_emission_hz=1.0, K=2),
    )
    return Executor(
        config=config,
        module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1),
        module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed),
        module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed),
        writer=LanguageColumnsWriter(),
        validator=StagingValidator(),
    )
|
||||
|
||||
|
||||
def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
    single_episode_root: Path,
) -> None:
    """End-to-end: run the pipeline, then render its parquet output with the
    canonical PR 1 recipe and check at least one frame yields messages.

    Also verifies the speech tool-call atom and the dataset-level ``tools``
    column survive the round trip intact.
    """
    executor = _build_executor()
    summary = executor.run(single_episode_root)
    # validator may emit warnings but no errors for the synthetic fixture
    assert summary.validation_report.ok, summary.validation_report.summary()

    # The writer rewrites the shard in place; re-read it from disk.
    table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
    persistent_lists = table.column("language_persistent").to_pylist()
    events_lists = table.column("language_events").to_pylist()
    timestamps = table.column("timestamp").to_pylist()

    recipe = TrainingRecipe.from_yaml(_RECIPE_PATH) if hasattr(TrainingRecipe, "from_yaml") else None
    if recipe is None:
        # PR 1 may not expose from_yaml; load via PyYAML and TrainingRecipe(**...)
        import yaml

        loaded = yaml.safe_load(_RECIPE_PATH.read_text(encoding="utf-8"))
        recipe = TrainingRecipe(**loaded)

    # Scan frames in order until one renders a non-empty message list.
    rendered_any = False
    for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
        result = render_sample(
            recipe=recipe,
            persistent=persistent,
            events=events,
            t=float(ts),
            sample_idx=0,
            dataset_ctx={"task": "Pour water from the bottle into the cup."},
        )
        if result is None:
            continue
        if result["messages"]:
            rendered_any = True
            # A renderable sample must also mark which messages are targets.
            assert result["target_message_indices"]
            break
    assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"

    # Sanity: speech atom appears in events column intact
    flat_events = [r for ev in events_lists for r in ev]
    speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
    assert speech_rows
    say = speech_rows[0]["tool_calls"][0]
    assert say["function"]["name"] == "say"
    assert isinstance(say["function"]["arguments"]["text"], str)
    # Tools column carries the say schema
    tools = json.loads(table.column("tools").to_pylist()[0])
    assert tools and tools[0]["function"]["name"] == "say"
|
||||
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Validator behavior tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import speech_atom
|
||||
|
||||
|
||||
def _validate(root: Path, staging_dir: Path):
    """Run a fresh StagingValidator over every episode found under ``root``."""
    episode_records = list(iter_episodes(root))
    validator = StagingValidator()
    return validator.validate(episode_records, staging_dir)
|
||||
|
||||
|
||||
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """A staged event whose timestamp sits between frames must be rejected."""
    staging_dir = tmp_path / "stage"
    off_grid_row = {
        "role": "assistant",
        "content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
        "style": "vqa",
        "timestamp": 9.999,  # not on any 10 fps frame
        "tool_calls": None,
    }
    EpisodeStaging(staging_dir, 0).write("module_3", [off_grid_row])

    report = _validate(fixture_dataset_root, staging_dir)
    assert not report.ok
    assert any("does not match any source frame timestamp" in err for err in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """An interjection lacking a co-timestamped speech atom must be rejected."""
    staging_dir = tmp_path / "stage"
    # interjection at 0.3s with NO paired speech
    unpaired_interjection = {
        "role": "user",
        "content": "skip it",
        "style": "interjection",
        "timestamp": 0.3,
        "tool_calls": None,
    }
    EpisodeStaging(staging_dir, 0).write(
        "module_2",
        [speech_atom(0.0, "Got it."), unpaired_interjection],
    )

    report = _validate(fixture_dataset_root, staging_dir)
    assert not report.ok
    assert any("paired speech" in err for err in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """An interjection without a co-timestamped plan refresh must be rejected."""
    staging_dir = tmp_path / "stage"
    module_1_rows = [
        {
            "role": "assistant",
            "content": "1. do x",
            "style": "plan",
            "timestamp": 0.0,
            "tool_calls": None,
        },
        {
            "role": "assistant",
            "content": "do x",
            "style": "subtask",
            "timestamp": 0.0,
            "tool_calls": None,
        },
    ]
    # Interjection + paired speech at 0.4s, but module_1 never refreshes the plan.
    module_2_rows = [
        speech_atom(0.0, "Got it."),
        speech_atom(0.4, "Replanning."),
        {
            "role": "user",
            "content": "replan",
            "style": "interjection",
            "timestamp": 0.4,
            "tool_calls": None,
        },
    ]
    EpisodeStaging(staging_dir, 0).write("module_1", module_1_rows)
    EpisodeStaging(staging_dir, 0).write("module_2", module_2_rows)

    report = _validate(fixture_dataset_root, staging_dir)
    # missing co-timestamped plan refresh at 0.4s → error
    assert not report.ok
    assert any("co-timestamped plan update" in err for err in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Module 1 may only stage persistent styles; a vqa row there is an error."""
    staging_dir = tmp_path / "stage"
    misrouted_row = {
        "role": "user",
        "content": "where?",
        "style": "vqa",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    EpisodeStaging(staging_dir, 0).write("module_1", [misrouted_row])

    report = _validate(fixture_dataset_root, staging_dir)
    assert not report.ok
    assert any("module_1 emitted style 'vqa'" in err or "must be persistent" in err for err in report.errors)
|
||||
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Writer correctness tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
LanguageColumnsWriter,
|
||||
speech_atom,
|
||||
)
|
||||
|
||||
|
||||
def _stage_episode(
    staging_dir: Path,
    episode_index: int,
    *,
    module_1: list[dict] | None = None,
    module_2: list[dict] | None = None,
    module_3: list[dict] | None = None,
) -> None:
    """Stage the given rows for one episode; modules passed as None are skipped."""
    staging = EpisodeStaging(staging_dir, episode_index)
    rows_by_module = {
        "module_1": module_1,
        "module_2": module_2,
        "module_3": module_3,
    }
    for module_name, rows in rows_by_module.items():
        if rows is not None:
            staging.write(module_name, rows)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Every frame in an episode has a byte-identical persistent list."""

    def atom(content: str, style: str, ts: float) -> dict:
        # Module-1 atoms are always assistant turns with no tool calls.
        return {"role": "assistant", "content": content, "style": style, "timestamp": ts, "tool_calls": None}

    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        module_1=[
            atom("grasp the sponge", "subtask", 0.0),
            atom("1. wipe\n2. dry", "plan", 0.0),
            atom("wiped the counter", "memory", 0.5),
        ],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)

    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
    persistent = table.column("language_persistent").to_pylist()
    first = persistent[0]
    assert first  # non-empty
    for row in persistent[1:]:
        assert row == first, "persistent slice must be byte-identical across all frames"
|
||||
|
||||
|
||||
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Event atoms land only on the frame whose timestamp exactly matches theirs."""
    staging_dir = tmp_path / "stage"
    interjection = {
        "role": "user",
        "content": "skip the dishes",
        "style": "interjection",
        "timestamp": 0.5,
        "tool_calls": None,
    }
    _stage_episode(
        staging_dir,
        0,
        module_2=[speech_atom(0.0, "Got it."), interjection, speech_atom(0.5, "Skipping the dishes.")],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)

    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
    timestamps = table.column("timestamp").to_pylist()
    events = table.column("language_events").to_pylist()
    for ts, ev in zip(timestamps, events, strict=True):
        if abs(ts - 0.0) < 1e-9:
            # t=0 frame carries the speech-only acknowledgement (style None).
            assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
        elif abs(ts - 0.5) < 1e-9:
            # t=0.5 frame carries the interjection plus its paired speech atom.
            assert any(r.get("style") == "interjection" for r in ev), ev
            assert any(r.get("style") is None for r in ev), ev
        else:
            # Every other frame has no events at all.
            assert ev == []
|
||||
|
||||
|
||||
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Persistent styles route to ``language_persistent``; event styles to ``language_events``."""

    def atom(role: str, content: str, style: str, ts: float) -> dict:
        return {"role": role, "content": content, "style": style, "timestamp": ts, "tool_calls": None}

    vqa_answer = json.dumps(
        {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
        sort_keys=True,
    )
    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        module_1=[
            atom("assistant", "do X", "subtask", 0.0),
            atom("assistant", "1. do X", "plan", 0.0),
            atom("assistant", "did X", "memory", 0.3),
        ],
        module_2=[
            speech_atom(0.0, "OK"),
            atom("user", "wait", "interjection", 0.2),
            speech_atom(0.2, "Waiting"),
        ],
        module_3=[
            atom("user", "where is the cup?", "vqa", 0.4),
            atom("assistant", vqa_answer, "vqa", 0.4),
        ],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")

    # Persistent column: module-1 styles only, identical on every frame (check frame 0).
    persistent = table.column("language_persistent").to_pylist()[0]
    persistent_styles = {r["style"] for r in persistent}
    assert persistent_styles == {"subtask", "plan", "memory"}

    # Events column: speech atoms (style None), interjections, and VQA pairs.
    all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
    event_styles = {r.get("style") for r in all_events}
    assert event_styles == {None, "interjection", "vqa"}
|
||||
|
||||
|
||||
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The writer drops ``subtask_index``, adds the language columns, and is idempotent."""
    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        module_1=[
            {"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
        ],
    )
    writer = LanguageColumnsWriter()
    writer.write_all(list(iter_episodes(fixture_dataset_root)), staging_dir, fixture_dataset_root)

    path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    table_a = pq.read_table(path)
    assert "subtask_index" not in table_a.column_names
    for column in ("language_persistent", "language_events", "tools"):
        assert column in table_a.column_names

    # Second pass over the already-rewritten dataset — the language columns
    # must come out value-identical to the first pass.
    records_again = list(iter_episodes(fixture_dataset_root))
    writer.write_all(records_again, staging_dir, fixture_dataset_root)
    table_b = pq.read_table(path)
    for column in ("language_persistent", "language_events"):
        assert table_a.column(column).to_pylist() == table_b.column(column).to_pylist()
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
    """``_normalize_persistent_row`` must reject any non-persistent style."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row

    bad_row = {"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
    with pytest.raises(ValueError, match="non-persistent style"):
        _normalize_persistent_row(bad_row)
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_event_style() -> None:
    """``_normalize_event_row`` must reject any persistent style."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row

    bad_row = {"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None}
    with pytest.raises(ValueError):
        _normalize_event_row(bad_row)
|
||||
|
||||
|
||||
def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The writer adds a dataset-level ``tools`` column carrying the ``say`` tool schema."""
    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        module_1=[
            {"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
        ],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
    tools = table.column("tools").to_pylist()
    assert tools, "tools column missing"

    decoded = json.loads(tools[0])
    assert isinstance(decoded, list)
    assert len(decoded) == 1
    say_fn = decoded[0]["function"]
    assert say_fn["name"] == "say"
    assert say_fn["parameters"]["properties"]["text"]["type"] == "string"
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
    """A speech atom is a content-free assistant turn whose text rides in a ``say`` tool call."""
    atom = speech_atom(2.5, "I'm cleaning up!")
    assert atom["role"] == "assistant"
    assert atom["style"] is None
    assert atom["content"] is None
    assert atom["timestamp"] == 2.5

    calls = atom["tool_calls"]
    assert isinstance(calls, list)
    say_call = calls[0]
    assert say_call["type"] == "function"
    assert say_call["function"]["name"] == "say"
    assert say_call["function"]["arguments"]["text"] == "I'm cleaning up!"
|
||||
Reference in New Issue
Block a user