Files
lerobot/tests/annotations/test_writer.py
T
Pepijn 58ccc01508 fix(datasets): enforce one parquet row group per episode in v3 data writes (#3807)
* fix(datasets): enforce one parquet row group per episode in v3 data writes

LeRobot v3 data shards must hold exactly one row group per episode so a
reader can fetch episode i with pq.ParquetFile(path).read_row_group(i)
(a byte-range read) instead of loading the whole shard. The recording
writer already does this (one write_table per episode); the aggregate
and lerobot-annotate re-write paths instead concatenated many episodes
and wrote them in one shot, collapsing the file to a single row group.

- io_utils: add write_table_one_row_group_per_episode (one ParquetWriter,
  one write_table per episode — same pattern as the recording writer);
  to_parquet_with_hf_images embeds images then writes per-episode row
  groups; to_parquet_one_row_group_per_episode wraps it for plain frames
- aggregate: route non-image data writes through the per-episode writer;
  leave the episodes-metadata parquet untouched (already one row/episode)
- annotate: rewrite shards via the per-episode writer instead of a single
  bulk pq.write_table
- tests: invariant coverage through the aggregate (image + video) and
  annotate paths

No change to on-disk schema, paths, naming, rollover thresholds, or
compression. Readers stay backward-compatible (old collapsed files load).

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(datasets): correct indentation and add strict= in row-group helper

The web-edited numpy version of write_table_one_row_group_per_episode had an
over-indented line (IndentationError, breaking pre-commit + test collection)
and a zip() without strict=. Fix both; behaviour unchanged.

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-16 12:15:48 +02:00

431 lines
16 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Writer correctness tests."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
# ``pyarrow`` and the ``lerobot.annotations`` -> ``lerobot.datasets`` chain
# (-> the HF ``datasets`` library) only ship under the ``dataset`` extra.
# Skip this module in tiers without it instead of erroring at import.
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging # noqa: E402
from lerobot.annotations.steerable_pipeline.writer import ( # noqa: E402
LanguageColumnsWriter,
speech_atom,
)
def _stage_episode(
staging_dir: Path,
episode_index: int,
*,
plan: list[dict] | None = None,
interjections: list[dict] | None = None,
vqa: list[dict] | None = None,
) -> None:
staging = EpisodeStaging(staging_dir, episode_index)
if plan is not None:
staging.write("plan", plan)
if interjections is not None:
staging.write("interjections", interjections)
if vqa is not None:
staging.write("vqa", vqa)
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
"""Every frame in an episode has a byte-identical persistent list."""
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
plan=[
{
"role": "assistant",
"content": "grasp the sponge",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "1. wipe\n2. dry",
"style": "plan",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "wiped the counter",
"style": "memory",
"timestamp": 0.5,
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
persistent = table.column("language_persistent").to_pylist()
first = persistent[0]
assert first # non-empty
for row in persistent:
assert row == first, "persistent slice must be byte-identical across all frames"
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
interjections=[
speech_atom(0.0, "Got it."),
{
"role": "user",
"content": "skip the dishes",
"style": "interjection",
"timestamp": 0.5,
"tool_calls": None,
},
speech_atom(0.5, "Skipping the dishes."),
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
timestamps = table.column("timestamp").to_pylist()
events = table.column("language_events").to_pylist()
for ts, ev in zip(timestamps, events, strict=True):
if abs(ts - 0.0) < 1e-9:
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
elif abs(ts - 0.5) < 1e-9:
assert any(r.get("style") == "interjection" for r in ev), ev
assert any(r.get("style") is None for r in ev), ev
else:
assert ev == []
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
plan=[
{
"role": "assistant",
"content": "do X",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "1. do X",
"style": "plan",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "did X",
"style": "memory",
"timestamp": 0.3,
"tool_calls": None,
},
],
interjections=[
speech_atom(0.0, "OK"),
{
"role": "user",
"content": "wait",
"style": "interjection",
"timestamp": 0.2,
"tool_calls": None,
},
speech_atom(0.2, "Waiting"),
],
vqa=[
{
"role": "user",
"content": "where is the cup?",
"style": "vqa",
"timestamp": 0.4,
"camera": "observation.images.front",
"tool_calls": None,
},
{
"role": "assistant",
"content": json.dumps(
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
sort_keys=True,
),
"style": "vqa",
"timestamp": 0.4,
"camera": "observation.images.front",
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
persistent = table.column("language_persistent").to_pylist()[0]
persistent_styles = {r["style"] for r in persistent}
assert persistent_styles == {"subtask", "plan", "memory"}
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
event_styles = {r.get("style") for r in all_events}
assert event_styles == {None, "interjection", "vqa"}
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
plan=[
{
"role": "assistant",
"content": "do X",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
writer = LanguageColumnsWriter()
writer.write_all(records, staging_dir, fixture_dataset_root)
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
table_a = pq.read_table(path)
assert "subtask_index" not in table_a.column_names
assert "language_persistent" in table_a.column_names
assert "language_events" in table_a.column_names
# The writer no longer emits a dataset-level ``tools`` column; the
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
# so the parquet stays small and the pipeline doesn't extend the schema.
assert "tools" not in table_a.column_names
# second pass — must produce identical bytes for the language columns
records_again = list(iter_episodes(fixture_dataset_root))
writer.write_all(records_again, staging_dir, fixture_dataset_root)
table_b = pq.read_table(path)
assert (
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
)
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
"""``_normalize_persistent_row`` must reject any non-persistent style."""
from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row
with pytest.raises(ValueError, match="non-persistent style"):
_normalize_persistent_row(
{"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
)
def test_writer_normalize_rejects_misrouted_event_style() -> None:
"""``_normalize_event_row`` must reject any persistent style."""
from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row
with pytest.raises(ValueError):
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
def test_say_tool_schema_constant_is_well_formed() -> None:
"""``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
``tools`` column — chat-template consumers import them directly.
"""
from lerobot.annotations.steerable_pipeline.writer import (
DEFAULT_TOOLS,
SAY_TOOL_SCHEMA,
)
assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
params = SAY_TOOL_SCHEMA["function"]["parameters"]
assert params["properties"]["text"]["type"] == "string"
assert params["required"] == ["text"]
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
"""Re-running on a parquet that already has a legacy ``tools`` column
must drop it cleanly so reruns converge to the v3.1 schema.
"""
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
plan=[
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
assert "tools" not in table.column_names
def test_annotation_metadata_sync_allows_non_streaming_load(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""Annotated parquet columns must be declared in ``meta/info.json``.
``LeRobotDataset`` loads non-streaming datasets by casting parquet
against metadata-derived HF features. If the annotation writer adds
language columns but metadata stays stale, that cast fails with a column
mismatch.
"""
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import load_info, load_nested_dataset
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info
info_path = fixture_dataset_root / "meta" / "info.json"
info = json.loads(info_path.read_text())
info["features"] = {
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
info_path.write_text(json.dumps(info, indent=2))
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
plan=[
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)
synced = load_info(fixture_dataset_root)
for key, feature in language_feature_info().items():
assert synced["features"][key] == feature
hf_features = get_hf_features_from_features(synced["features"])
dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)
assert LANGUAGE_PERSISTENT in dataset.column_names
assert LANGUAGE_EVENTS in dataset.column_names
assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
assert atom["style"] is None
assert atom["content"] is None
assert atom["timestamp"] == 2.5
assert isinstance(atom["tool_calls"], list)
call = atom["tool_calls"][0]
assert call["type"] == "function"
assert call["function"]["name"] == "say"
assert call["function"]["arguments"]["text"] == "I'm cleaning up!"