Files
lerobot/tests/annotations/test_writer.py
T
Pepijn a635a32290 feat: language annotation pipeline (PR 2/3)
Adds the steerable annotation pipeline (`lerobot-annotate`) that populates
the `language_persistent` and `language_events` columns introduced in
PR 1 directly into `data/chunk-*/file-*.parquet`. No flavor namespace,
no sidecar tree.

Modules produced:
- Module 1 (plan_subtasks_memory): Pi0.7-style subtasks, plan (init +
  refresh on interjection), MEM-style memory at subtask boundaries.
- Module 2 (interjections_and_speech): t=0 speech-only acknowledgement,
  mid-episode paired interjection + speech tool-call atom.
- Module 3 (general_vqa): bbox/keypoint/count/attribute/spatial pairs at
  configurable cadence with one-retry JSON validation.

Writer enforces: per-episode persistent identity, exact-frame event
timestamps, column routing per `column_for_style`, dataset-level `tools`
column with the `say` schema, drops legacy `subtask_index`. Validator
runs against staged JSONL artifacts before the writer rewrites parquet.

Adds `lerobot-annotate` console script, `annotations` extra (datatrove +
optional vllm), `make annotation-e2e` opt-in smoke target, and
`docs/source/annotation_pipeline.mdx`.

Branched from PR 1 (`feat/language-columns`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 16:22:51 +02:00

284 lines
10 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Writer correctness tests."""
from __future__ import annotations
import json
from pathlib import Path
import pyarrow.parquet as pq
import pytest
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
from lerobot.annotations.steerable_pipeline.writer import (
LanguageColumnsWriter,
speech_atom,
)
def _stage_episode(
    staging_dir: Path,
    episode_index: int,
    *,
    module_1: list[dict] | None = None,
    module_2: list[dict] | None = None,
    module_3: list[dict] | None = None,
) -> None:
    """Stage rows for one episode, issuing one ``write`` per supplied module.

    A module whose argument is ``None`` is skipped entirely; any other value
    (including an empty list) is passed to ``EpisodeStaging.write`` under the
    matching ``"module_N"`` name.
    """
    episode_staging = EpisodeStaging(staging_dir, episode_index)
    rows_by_module = (
        ("module_1", module_1),
        ("module_2", module_2),
        ("module_3", module_3),
    )
    for module_name, rows in rows_by_module:
        if rows is None:
            continue
        episode_staging.write(module_name, rows)
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Every frame in an episode has a byte-identical persistent list."""

    def _assistant_row(content: str, style: str, timestamp: float) -> dict:
        # Every staged row here is an assistant row without tool calls.
        return {
            "role": "assistant",
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    staging_dir = tmp_path / "stage"
    staged_rows = [
        _assistant_row("grasp the sponge", "subtask", 0.0),
        _assistant_row("1. wipe\n2. dry", "plan", 0.0),
        _assistant_row("wiped the counter", "memory", 0.5),
    ]
    _stage_episode(staging_dir, 0, module_1=staged_rows)

    episode_records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episode_records, staging_dir, fixture_dataset_root)

    parquet_path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    persistent_rows = pq.read_table(parquet_path).column("language_persistent").to_pylist()

    reference = persistent_rows[0]
    assert reference  # non-empty
    for frame_value in persistent_rows:
        assert frame_value == reference, "persistent slice must be byte-identical across all frames"
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Events appear only on the frames whose timestamp matches the staged one.

    Stages a speech atom at t=0.0 and an interjection plus speech atom at
    t=0.5, runs the writer, then checks per frame that:

    * the t=0.0 frame carries an assistant row with ``style`` of ``None``;
    * the t=0.5 frame carries both an interjection row and a style-less row;
    * every other frame has an empty event list.

    The test also requires that both target frames are actually present, so
    it cannot pass vacuously when no frame matches a staged timestamp.
    """
    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        module_2=[
            speech_atom(0.0, "Got it."),
            {
                "role": "user",
                "content": "skip the dishes",
                "style": "interjection",
                "timestamp": 0.5,
                "tool_calls": None,
            },
            speech_atom(0.5, "Skipping the dishes."),
        ],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)

    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
    timestamps = table.column("timestamp").to_pylist()
    events = table.column("language_events").to_pylist()

    # Track which staged timestamps were matched by at least one frame.
    matched: set[float] = set()
    for ts, ev in zip(timestamps, events, strict=True):
        if abs(ts - 0.0) < 1e-9:
            matched.add(0.0)
            assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
        elif abs(ts - 0.5) < 1e-9:
            matched.add(0.5)
            assert any(r.get("style") == "interjection" for r in ev), ev
            assert any(r.get("style") is None for r in ev), ev
        else:
            assert ev == []

    # Without this guard the branch assertions above are skipped silently when
    # the parquet file has no frame at a staged timestamp.
    assert matched == {0.0, 0.5}, f"no frame found for staged timestamps: {sorted({0.0, 0.5} - matched)}"
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """Staged rows are routed to the column that matches their style.

    Subtask/plan/memory rows must show up in ``language_persistent``; speech
    atoms (style ``None``), interjections and vqa rows must show up in
    ``language_events``.
    """

    def _row(role: str, content: str, style: str, timestamp: float) -> dict:
        # Plain (non tool-call) staged row.
        return {
            "role": role,
            "content": content,
            "style": style,
            "timestamp": timestamp,
            "tool_calls": None,
        }

    staging_dir = tmp_path / "stage"

    persistent_rows = [
        _row("assistant", "do X", "subtask", 0.0),
        _row("assistant", "1. do X", "plan", 0.0),
        _row("assistant", "did X", "memory", 0.3),
    ]
    speech_rows = [
        speech_atom(0.0, "OK"),
        _row("user", "wait", "interjection", 0.2),
        speech_atom(0.2, "Waiting"),
    ]
    vqa_answer = json.dumps(
        {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
        sort_keys=True,
    )
    vqa_rows = [
        _row("user", "where is the cup?", "vqa", 0.4),
        _row("assistant", vqa_answer, "vqa", 0.4),
    ]
    _stage_episode(
        staging_dir,
        0,
        module_1=persistent_rows,
        module_2=speech_rows,
        module_3=vqa_rows,
    )

    episode_records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episode_records, staging_dir, fixture_dataset_root)

    parquet_path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    table = pq.read_table(parquet_path)

    first_frame_persistent = table.column("language_persistent").to_pylist()[0]
    persistent_styles = {entry["style"] for entry in first_frame_persistent}
    assert persistent_styles == {"subtask", "plan", "memory"}

    # Collect the styles of every event row across all frames.
    event_styles = set()
    for frame_events in table.column("language_events").to_pylist():
        for entry in frame_events:
            event_styles.add(entry.get("style"))
    assert event_styles == {None, "interjection", "vqa"}
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The writer drops ``subtask_index`` and a second run leaves language columns unchanged."""
    staging_dir = tmp_path / "stage"
    subtask_row = {
        "role": "assistant",
        "content": "do X",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(staging_dir, 0, module_1=[subtask_row])

    parquet_path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    writer = LanguageColumnsWriter()

    # First pass.
    writer.write_all(list(iter_episodes(fixture_dataset_root)), staging_dir, fixture_dataset_root)
    first_pass = pq.read_table(parquet_path)
    assert "subtask_index" not in first_pass.column_names
    for required in ("language_persistent", "language_events", "tools"):
        assert required in first_pass.column_names

    # Second pass over the already-rewritten dataset: the language columns
    # must come out the same as after the first pass.
    writer.write_all(list(iter_episodes(fixture_dataset_root)), staging_dir, fixture_dataset_root)
    second_pass = pq.read_table(parquet_path)
    for language_column in ("language_persistent", "language_events"):
        before = first_pass.column(language_column).to_pylist()
        after = second_pass.column(language_column).to_pylist()
        assert before == after
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
    """``_normalize_persistent_row`` must raise ``ValueError`` for a ``vqa``-styled row."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row

    misrouted_row = {
        "role": "assistant",
        "content": "oops",
        "style": "vqa",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    with pytest.raises(ValueError, match="non-persistent style"):
        _normalize_persistent_row(misrouted_row)
def test_writer_normalize_rejects_misrouted_event_style() -> None:
    """``_normalize_event_row`` must raise ``ValueError`` for a ``subtask``-styled row."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row

    misrouted_row = {
        "role": "assistant",
        "content": "oops",
        "style": "subtask",
        "tool_calls": None,
    }
    with pytest.raises(ValueError):
        _normalize_event_row(misrouted_row)
def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The ``tools`` column holds a JSON list with exactly one entry, the ``say`` function."""
    staging_dir = tmp_path / "stage"
    minimal_row = {
        "role": "assistant",
        "content": "x",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    _stage_episode(staging_dir, 0, module_1=[minimal_row])

    episode_records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(episode_records, staging_dir, fixture_dataset_root)

    parquet_path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
    tools_values = pq.read_table(parquet_path).column("tools").to_pylist()
    assert tools_values, "tools column missing"

    # Each cell is a JSON-encoded string; inspect the first frame's value.
    tool_schemas = json.loads(tools_values[0])
    assert isinstance(tool_schemas, list)
    assert len(tool_schemas) == 1

    function_spec = tool_schemas[0]["function"]
    assert function_spec["name"] == "say"
    text_property = function_spec["parameters"]["properties"]["text"]
    assert text_property["type"] == "string"
def test_speech_atom_shape_matches_plan_spec() -> None:
    """``speech_atom`` returns a style-less, content-less assistant row whose first tool call is ``say``."""
    spoken_text = "I'm cleaning up!"
    atom = speech_atom(2.5, spoken_text)

    assert atom["role"] == "assistant"
    assert atom["style"] is None
    assert atom["content"] is None
    assert atom["timestamp"] == 2.5

    tool_calls = atom["tool_calls"]
    assert isinstance(tool_calls, list)

    first_call = tool_calls[0]
    assert first_call["type"] == "function"
    function_part = first_call["function"]
    assert function_part["name"] == "say"
    assert function_part["arguments"]["text"] == spoken_text