#!/usr/bin/env python # Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Writer correctness tests.""" from __future__ import annotations import json from pathlib import Path import pyarrow.parquet as pq import pytest from lerobot.annotations.steerable_pipeline.reader import iter_episodes from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging from lerobot.annotations.steerable_pipeline.writer import ( LanguageColumnsWriter, speech_atom, ) def _stage_episode( staging_dir: Path, episode_index: int, *, module_1: list[dict] | None = None, module_2: list[dict] | None = None, module_3: list[dict] | None = None, ) -> None: staging = EpisodeStaging(staging_dir, episode_index) if module_1 is not None: staging.write("module_1", module_1) if module_2 is not None: staging.write("module_2", module_2) if module_3 is not None: staging.write("module_3", module_3) def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None: """Every frame in an episode has a byte-identical persistent list.""" staging_dir = tmp_path / "stage" _stage_episode( staging_dir, 0, module_1=[ { "role": "assistant", "content": "grasp the sponge", "style": "subtask", "timestamp": 0.0, "tool_calls": None, }, { "role": "assistant", "content": "1. wipe\n2. dry", "style": "plan", "timestamp": 0.0, "tool_calls": None, }, { "role": "assistant", "content": "wiped the counter", "style": "memory", "timestamp": 0.5, "tool_calls": None, }, ], ) records = list(iter_episodes(fixture_dataset_root)) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") persistent = table.column("language_persistent").to_pylist() first = persistent[0] assert first # non-empty for row in persistent: assert row == first, "persistent slice must be byte-identical across all frames" def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None: staging_dir = tmp_path / "stage" _stage_episode( staging_dir, 0, module_2=[ speech_atom(0.0, "Got it."), { "role": "user", "content": "skip the dishes", "style": "interjection", "timestamp": 0.5, "tool_calls": None, }, speech_atom(0.5, "Skipping the dishes."), ], ) records = list(iter_episodes(fixture_dataset_root)) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") timestamps = table.column("timestamp").to_pylist() events = table.column("language_events").to_pylist() for ts, ev in zip(timestamps, events, strict=True): if abs(ts - 0.0) < 1e-9: assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev elif abs(ts - 0.5) < 1e-9: assert any(r.get("style") == "interjection" for r in ev), ev assert any(r.get("style") is None for r in ev), ev else: assert ev == [] def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None: staging_dir = tmp_path / "stage" _stage_episode( staging_dir, 0, module_1=[ { "role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None, }, { "role": "assistant", "content": "1. do X", "style": "plan", "timestamp": 0.0, "tool_calls": None, }, { "role": "assistant", "content": "did X", "style": "memory", "timestamp": 0.3, "tool_calls": None, }, ], module_2=[ speech_atom(0.0, "OK"), { "role": "user", "content": "wait", "style": "interjection", "timestamp": 0.2, "tool_calls": None, }, speech_atom(0.2, "Waiting"), ], module_3=[ { "role": "user", "content": "where is the cup?", "style": "vqa", "timestamp": 0.4, "tool_calls": None, }, { "role": "assistant", "content": json.dumps( {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]}, sort_keys=True, ), "style": "vqa", "timestamp": 0.4, "tool_calls": None, }, ], ) records = list(iter_episodes(fixture_dataset_root)) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") persistent = table.column("language_persistent").to_pylist()[0] persistent_styles = {r["style"] for r in persistent} assert persistent_styles == {"subtask", "plan", "memory"} all_events = [r for ev in table.column("language_events").to_pylist() for r in ev] event_styles = {r.get("style") for r in all_events} assert event_styles == {None, "interjection", "vqa"} def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None: staging_dir = tmp_path / "stage" _stage_episode( staging_dir, 0, module_1=[ { "role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None, }, ], ) records = list(iter_episodes(fixture_dataset_root)) writer = LanguageColumnsWriter() writer.write_all(records, staging_dir, fixture_dataset_root) path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet" table_a = pq.read_table(path) assert "subtask_index" not in table_a.column_names assert "language_persistent" in table_a.column_names assert "language_events" in table_a.column_names assert "tools" in table_a.column_names # second pass — must produce identical bytes for the language columns records_again = list(iter_episodes(fixture_dataset_root)) writer.write_all(records_again, staging_dir, fixture_dataset_root) table_b = pq.read_table(path) assert ( table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist() ) assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist() def test_writer_normalize_rejects_misrouted_persistent_style() -> None: """``_normalize_persistent_row`` must reject any non-persistent style.""" from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row with pytest.raises(ValueError, match="non-persistent style"): _normalize_persistent_row( {"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None} ) def test_writer_normalize_rejects_misrouted_event_style() -> None: """``_normalize_event_row`` must reject any persistent style.""" from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row with pytest.raises(ValueError): _normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None}) def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None: staging_dir = tmp_path / "stage" _stage_episode( staging_dir, 0, module_1=[ {"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None} ], ) records = list(iter_episodes(fixture_dataset_root)) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") tools = table.column("tools").to_pylist() assert tools, "tools column missing" decoded = json.loads(tools[0]) assert isinstance(decoded, list) assert len(decoded) == 1 assert decoded[0]["function"]["name"] == "say" params = decoded[0]["function"]["parameters"] assert params["properties"]["text"]["type"] == "string" def test_speech_atom_shape_matches_plan_spec() -> None: atom = speech_atom(2.5, "I'm cleaning up!") assert atom["role"] == "assistant" assert atom["style"] is None assert atom["content"] is None assert atom["timestamp"] == 2.5 assert isinstance(atom["tool_calls"], list) call = atom["tool_calls"][0] assert call["type"] == "function" assert call["function"]["name"] == "say" assert call["function"]["arguments"]["text"] == "I'm cleaning up!"