diff --git a/Makefile b/Makefile index e02f02403..d3987101f 100644 --- a/Makefile +++ b/Makefile @@ -178,3 +178,9 @@ test-smolvla-ete-eval: --env.episode_length=5 \ --eval.n_episodes=1 \ --eval.batch_size=1 + +# E2E annotation pipeline smoke test against a tiny in-memory fixture +# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM +# backend, so it does not require a real model checkpoint or GPU. +annotation-e2e: + uv run python -m tests.annotations.run_e2e_smoke diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5ca449145..662b5514f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -33,6 +33,8 @@ title: Using the Dataset Tools - local: language_and_recipes title: Language Columns and Recipes + - local: annotation_pipeline + title: Annotation Pipeline - local: streaming_video_encoding title: Streaming Video Encoding title: "Datasets" diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx new file mode 100644 index 000000000..6b88decc6 --- /dev/null +++ b/docs/source/annotation_pipeline.mdx @@ -0,0 +1,133 @@ +# Annotation Pipeline + +`lerobot-annotate` populates the two language columns introduced by the +[Language Columns and Recipes](./language_and_recipes) page — +`language_persistent` and `language_events` — directly into +`data/chunk-*/file-*.parquet`. There is no flavor namespace and no sidecar +file tree: multiple revisions of a dataset mean multiple dataset copies. + +## What the pipeline produces + +Three modules write into a per-episode staging tree, then a single writer +rewrites the data shards in place: + +| Style / atom | Column | Module | +| ------------------------------------------- | --------------------- | -------- | +| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 | +| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 | +| `memory` (MEM-style compression) | `language_persistent` | Module 1 | +| `interjection` | `language_events` | Module 2 | +| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 | +| `vqa` (user / assistant pair) | `language_events` | Module 3 | + +The writer also adds a dataset-level `tools` column carrying the JSON schema +for the `say` tool call, and drops the legacy `subtask_index` column. + +## How to run it locally or on SLURM + +Install the extra and invoke the console script: + +```bash +uv sync --extra annotations +uv run lerobot-annotate \ + --root=/path/to/dataset \ + --vlm.backend=transformers \ + --vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct +``` + +The executor picks `LocalPipelineExecutor` for small datasets and +`SlurmPipelineExecutor` for large ones based on +`--executor.auto_threshold` (default 32 episodes). Force local with +`--executor.force_local=true`. SLURM jobs honour `--executor.slurm_partition`, +`--executor.slurm_gpus`, and `--executor.slurm_time`. + +## Style-to-recipe consumer mapping + +The pipeline produces exactly the styles consumed by +`src/lerobot/configs/recipes/pi05_hirobot.yaml`: + +- `low_level_execution`, `high_level_subtask`, `memory_update` consume + `subtask`/`plan`/`memory` from `language_persistent`. +- `user_interjection_response` consumes `interjection` events plus the + paired speech atom (merged into one assistant target turn via + `tool_calls_from`) and the same-timestamp `plan` refresh. +- `ask_vqa` consumes the `(vqa, user)` and `(vqa, assistant)` pairs from + `language_events`. 
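
As a sketch of that consumer contract (not part of this diff; the rows shown are illustrative), the same `column_for_style` helper the validator uses can partition rows into the two columns:

```python
from lerobot.datasets.language import LANGUAGE_PERSISTENT, column_for_style

rows = [
    {"role": "assistant", "content": "pick up one piece of lettuce",
     "style": "subtask", "timestamp": 0.0, "tool_calls": None},
    {"role": "user", "content": "actually, skip the lettuce",
     "style": "interjection", "timestamp": 4.2, "tool_calls": None},
]
# Persistent styles broadcast per episode; everything else is an exact event.
persistent = [r for r in rows if column_for_style(r["style"]) == LANGUAGE_PERSISTENT]
events = [r for r in rows if column_for_style(r["style"]) != LANGUAGE_PERSISTENT]
```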

## Why the design is scoped to the canonical recipe

Two things drive the scope:

1. **Persistent state vs exact-event split.** Persistent rows (`subtask`,
   `plan`, `memory`) broadcast per episode and answer "what state is in
   force at this frame?". Event rows (`interjection`, `vqa`, speech) only
   appear on the exact frame whose timestamp matches the emission. The
   pipeline writes timestamps taken straight from the source parquet — no
   floating-point recomputation.
2. **One Qwen-VL pass.** All three modules share a single VLM client
   (vLLM if available, transformers fallback) so the cost is one model
   load per dataset, not three.

## Module independence and staged reruns

Each module writes its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
prompt iteration cheap — re-running one module overwrites only its own
JSONL file before the writer composes the final parquet. Modules can be
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
test them in isolation.

## Validation/report checks before final write

Before the writer runs, `StagingValidator` checks:

- exact frame-timestamp alignment for every event row;
- no orphan speech / interjection pairs;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (warning, not error);
- VQA assistant `content` parses as JSON in one of the
  bbox / keypoint / count / attribute / spatial shapes;
- every row routes to the column dictated by `column_for_style(style)`.

Errors abort the writer (`--skip_validation=true` overrides for debugging).

## Paper inspirations per module

- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
  atom granularity ("pick up one piece of lettuce", "place bowl to box");
  Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
  what" detail.
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
  compression directive: keep only minimal relevant information; functional
  outcomes preserved, specific attributes dropped.
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
  situated correction, specific constraint, preference. Speech is a
  tool-call-only atom
  (`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
  grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
  keypoints) and Steerable Policies' multi-abstraction grounding.

Future maintainers should adjust the prompt templates in
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
references rather than rewriting from scratch.

## Compute and list-size estimates

Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
O(`max_interjections_per_episode`) Module 2 calls, and
O(`vqa_emission_hz × K × episode_seconds`) Module 3 calls. With defaults
(8 subtasks, 1 interjection, 1 Hz × K=3 anchored frames) and 30-second
episodes, that is roughly 100 VLM calls per episode, about 90 of them from
Module 3. `language_persistent` per episode is tens of KB at most (parquet
dictionary-encodes one entry per episode); `language_events` is empty on
most frames and is bounded by the number of emissions, not
`num_frames × num_emissions`.

## Reproducibility via seed and prompt hashes

`--seed` (default 1729) feeds the per-episode RNGs that select interjection
timestamps and VQA question types.
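The draws are keyed per episode with plain strings, so they do not depend on
episode order or worker layout. A minimal sketch of the scheme (the key
formats match Modules 2 and 3; the episode index is illustrative):

```python
import random

seed, episode_index = 1729, 42
# One independent, reproducible stream per (episode, purpose) pair.
rng_interjection = random.Random(f"{seed}:{episode_index}:interjection")
rng_vqa = random.Random(f"{seed}:{episode_index}:vqa")
```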
Combined with the deterministic prompt +templates checked into `prompts/`, two runs at the same seed against the +same dataset and the same model checkpoint produce byte-identical staging +artifacts. Prompt edits are recorded by file hash; future tooling can pin +expected `(seed, prompt_hash)` pairs into the dataset card. diff --git a/pyproject.toml b/pyproject.toml index 0790db6fb..49b67ca42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,15 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] +# Annotation pipeline (lerobot-annotate). datatrove is mandatory; vllm is +# the preferred backend on Linux, with a transformers fallback elsewhere. +annotations = [ + "lerobot[dataset]", + "lerobot[transformers-dep]", + "datatrove>=0.4.0,<2.0.0", + "vllm>=0.6.0,<1.0.0; sys_platform == 'linux'", +] + # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"] @@ -289,6 +298,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" +lerobot-annotate="lerobot.scripts.lerobot_annotate:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.package-data] diff --git a/src/lerobot/annotations/__init__.py b/src/lerobot/annotations/__init__.py new file mode 100644 index 000000000..67782f192 --- /dev/null +++ b/src/lerobot/annotations/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/annotations/steerable_pipeline/__init__.py b/src/lerobot/annotations/steerable_pipeline/__init__.py new file mode 100644 index 000000000..ca87b9654 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/__init__.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Steerable annotation pipeline producing ``language_persistent`` and +``language_events`` columns for LeRobot datasets. 
+ +The pipeline is decomposed into three independently runnable modules whose +outputs are staged per-episode before a final parquet rewrite: + +- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles +- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech +- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs +""" + +from .config import AnnotationPipelineConfig +from .validator import StagingValidator, ValidationReport +from .writer import LanguageColumnsWriter + +__all__ = [ + "AnnotationPipelineConfig", + "LanguageColumnsWriter", + "StagingValidator", + "ValidationReport", +] diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py new file mode 100644 index 000000000..eed745086 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal + + +@dataclass +class Module1Config: + """Module 1 hyperparameters: plan + subtasks + memory.""" + + enabled: bool = True + keyframes_per_episode: int = 8 + min_subtask_seconds: float = 1.5 + plan_max_steps: int = 8 + + +@dataclass +class Module2Config: + """Module 2 hyperparameters: interjections + paired speech.""" + + enabled: bool = True + max_interjections_per_episode: int = 1 + interjection_min_t: float = 2.0 + + +@dataclass +class Module3Config: + """Module 3 hyperparameters: general VQA.""" + + enabled: bool = True + vqa_emission_hz: float = 1.0 + K: int = 3 + question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial") + + +@dataclass +class VlmConfig: + """Shared Qwen-VL client configuration.""" + + backend: Literal["vllm", "transformers", "stub"] = "transformers" + model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct" + max_new_tokens: int = 512 + temperature: float = 0.2 + json_mode: bool = True + batch_size: int = 4 + + +@dataclass +class ExecutorConfig: + """Executor selection and SLURM hyperparameters.""" + + auto_threshold: int = 32 + force_local: bool = False + slurm_partition: str | None = None + slurm_gpus: int = 1 + slurm_time: str = "06:00:00" + workers: int = 1 + + +@dataclass +class AnnotationPipelineConfig: + """Top-level config for ``lerobot-annotate``. + + Mirrors the structure of :class:`lerobot.configs.train.TrainPipelineConfig`: + a draccus-parsed dataclass that contains nested per-module sub-configs and + leaves the dataset, executor, and VLM choices independently knobbable. + + Output is always in-place: the writer rewrites ``data/chunk-*/file-*.parquet`` + in place. Multiple revisions of the same dataset live in separate copies. 
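
    Example (illustrative path; in practice draccus parses these fields
    from the ``lerobot-annotate`` CLI)::

        cfg = AnnotationPipelineConfig(root=Path("/data/pick_place"))
        cfg.resolved_staging_dir(cfg.root)
        # -> Path("/data/pick_place/.annotate_staging")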
+ """ + + repo_id: str | None = None + root: Path | None = None + + staging_dir: Path | None = None + """If unset, defaults to ``/.annotate_staging/``.""" + + seed: int = 1729 + + module_1: Module1Config = field(default_factory=Module1Config) + module_2: Module2Config = field(default_factory=Module2Config) + module_3: Module3Config = field(default_factory=Module3Config) + + vlm: VlmConfig = field(default_factory=VlmConfig) + executor: ExecutorConfig = field(default_factory=ExecutorConfig) + + skip_validation: bool = False + only_episodes: tuple[int, ...] | None = None + + def resolved_staging_dir(self, root: Path) -> Path: + return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging" diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py new file mode 100644 index 000000000..284eeabc5 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Executor selection: local vs SLURM via datatrove. + +The executor plans **four phases** with the dependency order from the plan: + + phase 1: Module 1 (plan + subtasks + memory) + phase 2: Module 2 (interjections + speech) + phase 3: Module 1 plan-update pass — re-runs plan emission at every + interjection timestamp produced by phase 2 + phase 4: Module 3 (VQA) + phase 5: validator + phase 6: writer + +Phase 3 is why ``executor.py`` documents the dependency: Module 1 must be +re-entered after Module 2 to refresh ``plan`` rows at interjection times. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .config import AnnotationPipelineConfig, ExecutorConfig +from .reader import EpisodeRecord, iter_episodes +from .staging import EpisodeStaging +from .validator import StagingValidator +from .writer import LanguageColumnsWriter + +logger = logging.getLogger(__name__) + + +@dataclass +class PhaseResult: + """Summary of one pipeline phase across all episodes.""" + + name: str + episodes_processed: int + episodes_skipped: int + + +@dataclass +class PipelineRunSummary: + """Aggregated result returned by :meth:`Executor.run`.""" + + phases: list[PhaseResult] + written_paths: list[Path] + validation_report: Any # ValidationReport, kept Any to avoid import cycle + + +def select_executor_class(num_episodes: int, config: ExecutorConfig) -> str: + """Return ``"local"`` or ``"slurm"`` based on the threshold. + + The plan's "executor selection threshold" lives in + :class:`ExecutorConfig.auto_threshold`. ``force_local`` always wins. + """ + if config.force_local: + return "local" + return "local" if num_episodes <= config.auto_threshold else "slurm" + + +@dataclass +class Executor: + """Run all four phases over a dataset root. 
    The executor is intentionally framework-agnostic: by default it runs the
    phases inline (suitable for tests, small datasets, and the CLI's
    ``--executor.force_local=true`` mode). Selection via
    :func:`select_executor_class` leaves room to hand off to datatrove's
    :class:`LocalPipelineExecutor` or :class:`SlurmPipelineExecutor` when
    those are installed and the dataset is large enough to benefit from them.

    Tests construct the executor directly with stub modules.
    """

    config: AnnotationPipelineConfig
    module_1: Any  # PlanSubtasksMemoryModule
    module_2: Any  # InterjectionsAndSpeechModule
    module_3: Any  # GeneralVqaModule
    writer: LanguageColumnsWriter
    validator: StagingValidator

    def run(self, root: Path) -> PipelineRunSummary:
        records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
        n = len(records)
        if n == 0:
            raise ValueError(f"No episodes found under {root}/data/")

        executor_kind = select_executor_class(n, self.config.executor)
        logger.info("annotate: %d episodes; executor=%s", n, executor_kind)

        staging_dir = self.config.resolved_staging_dir(root)
        staging_dir.mkdir(parents=True, exist_ok=True)

        phases: list[PhaseResult] = []

        # Phase 1: Module 1 (plan + subtasks + memory)
        phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
        # Phase 2: Module 2 (interjections + speech)
        phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
        # Phase 3: Module 1 plan-update pass at interjection timestamps.
        phases.append(self._run_plan_update_phase(records, staging_dir))
        # Phase 4: Module 3 (VQA)
        phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))

        report = self.validator.validate(records, staging_dir)
        if not report.ok and not self.config.skip_validation:
            raise RuntimeError(f"Staging validation failed: {report.summary()}")

        written = self.writer.write_all(records, staging_dir, root)
        return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)

    def _run_module_phase(
        self,
        name: str,
        records: list[EpisodeRecord],
        staging_dir: Path,
        module: Any,
    ) -> PhaseResult:
        if not module.enabled:
            return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
        processed = 0
        for record in records:
            staging = EpisodeStaging(staging_dir, record.episode_index)
            module.run_episode(record, staging)
            processed += 1
        return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)

    def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult:
        """Re-emit ``plan`` rows at each interjection timestamp from Module 2.

        Module 1 owns the prompt; Module 2 produced the timestamps. This phase
        therefore calls back into Module 1 with the interjection timestamps so
        Module 1's existing prompt path is reused.
+ """ + if not self.module_1.enabled or not self.module_2.enabled: + return PhaseResult( + name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records) + ) + processed = 0 + for record in records: + staging = EpisodeStaging(staging_dir, record.episode_index) + interjection_times = [ + row["timestamp"] for row in staging.read("module_2") if row.get("style") == "interjection" + ] + if interjection_times: + self.module_1.run_plan_updates(record, staging, interjection_times) + processed += 1 + return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0) diff --git a/src/lerobot/annotations/steerable_pipeline/modules/__init__.py b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py new file mode 100644 index 000000000..e9ff8ed23 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .general_vqa import GeneralVqaModule +from .interjections_and_speech import InterjectionsAndSpeechModule +from .plan_subtasks_memory import PlanSubtasksMemoryModule + +__all__ = [ + "GeneralVqaModule", + "InterjectionsAndSpeechModule", + "PlanSubtasksMemoryModule", +] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py new file mode 100644 index 000000000..8ea19ab00 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module 3: general VQA at a timed cadence. + +Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per +emission so each frame gets at most one ``(vqa, user)`` and one +``(vqa, assistant)`` pair — keeps the resolver contract scalar. + +Question types covered (per the plan's Module 3 table): bbox, keypoint, +count, attribute, spatial. The assistant's ``content`` is a JSON string +whose schema depends on the question type. Malformed JSON triggers one +retry inside :meth:`VlmClient.generate_json`. 
+""" + +from __future__ import annotations + +import json +import random +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +from ..config import Module3Config +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord +from ..staging import EpisodeStaging +from ..validator import classify_vqa_answer +from ..vlm_client import VlmClient + + +def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]: + """Return the relative frame indices to anchor VQA emissions to. + + For each emission tick (every ``1/hz`` seconds), we anchor ``k`` + consecutive frames starting at the tick. Ticks fall on the nearest + available source frame timestamp. + """ + if hz <= 0 or k <= 0 or not frame_timestamps: + return [] + t0 = frame_timestamps[0] + t_last = frame_timestamps[-1] + period = 1.0 / hz + indices: list[int] = [] + t = t0 + while t <= t_last + 1e-9: + # find the index of the nearest frame to t + nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t)) + for offset in range(k): + j = nearest_i + offset + if j >= len(frame_timestamps): + break + if not indices or indices[-1] != j: + indices.append(j) + t += period + # dedupe while preserving order + seen: set[int] = set() + deduped: list[int] = [] + for i in indices: + if i in seen: + continue + seen.add(i) + deduped.append(i) + return deduped + + +@dataclass +class GeneralVqaModule: + """Emit grounded VQA pairs at a timed cadence.""" + + vlm: VlmClient + config: Module3Config + seed: int = 1729 + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + if not record.frame_timestamps: + staging.write("module_3", []) + return + rng = random.Random(f"{self.seed}:{record.episode_index}:vqa") + anchor_idx = _emission_anchor_indices( + record.frame_timestamps, self.config.vqa_emission_hz, self.config.K + ) + rows: list[dict[str, Any]] = [] + for idx in anchor_idx: + ts = float(record.frame_timestamps[idx]) + qtype = rng.choice(self.config.question_types) + qa = self._generate_one(record, qtype) + if qa is None: + continue + question, answer = qa + rows.append( + { + "role": "user", + "content": question, + "style": "vqa", + "timestamp": ts, + "tool_calls": None, + } + ) + rows.append( + { + "role": "assistant", + "content": json.dumps(answer, sort_keys=True), + "style": "vqa", + "timestamp": ts, + "tool_calls": None, + } + ) + staging.write("module_3", rows) + + def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None: + prompt = load_prompt("module_3_vqa").format( + episode_task=record.episode_task, + question_type=question_type, + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if not isinstance(result, dict): + return None + question = result.get("question") + answer = result.get("answer") + if not isinstance(question, str) or not question.strip(): + return None + if not isinstance(answer, dict): + return None + # The validator will enforce shape; here we just sanity-check that the + # answer matches *some* known shape so we can drop garbage early. 
        if classify_vqa_answer(answer) is None:
            return None
        return question.strip(), answer
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py
new file mode 100644
index 000000000..776cfb79c
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py
@@ -0,0 +1,129 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).

Two sub-passes:

1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
   canonical task). No interjection row — the canonical task is already the
   user utterance from ``meta/tasks.parquet``.

2. For mid-episode interruptions, emit a co-timestamped pair:
       {role:user, style:interjection, content:<interjection text>}
       + speech atom (role:assistant, style:None, tool_calls=[say(...)])
   Both rows go in ``language_events`` at the same timestamp.

Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
timestamps to refresh the ``plan`` row at the same instant.
"""

from __future__ import annotations

import random
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

from ..config import Module2Config
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom


def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
    if not frame_timestamps:
        return float(t)
    return float(min(frame_timestamps, key=lambda f: abs(f - t)))


@dataclass
class InterjectionsAndSpeechModule:
    """Generate task-start speech and mid-episode interjection/speech pairs."""

    vlm: VlmClient
    config: Module2Config
    seed: int = 1729

    @property
    def enabled(self) -> bool:
        return self.config.enabled

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        rows: list[dict[str, Any]] = []
        if record.frame_timestamps:
            t0 = float(record.frame_timestamps[0])
            initial = self._initial_speech(record)
            if initial:
                rows.append(speech_atom(t0, initial))
        rows.extend(self._mid_episode_interjections(record))
        staging.write("module_2", rows)

    def _initial_speech(self, record: EpisodeRecord) -> str | None:
        prompt = load_prompt("module_2_initial_speech").format(
            episode_task=record.episode_task,
        )
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        result = self.vlm.generate_json([messages])[0]
        if isinstance(result, dict) and isinstance(result.get("text"), str):
            text = result["text"].strip()
            if text:
                return text
        return None

    def _mid_episode_interjections(self, record: EpisodeRecord) -> list[dict[str, Any]]:
        if self.config.max_interjections_per_episode <= 0:
            return []
        # Deterministic
per-episode RNG so reruns are stable across SLURM jobs. + rng = random.Random(f"{self.seed}:{record.episode_index}:interjection") + candidate_ts = [t for t in record.frame_timestamps if t >= self.config.interjection_min_t] + if not candidate_ts: + return [] + n = min(self.config.max_interjections_per_episode, len(candidate_ts) // 4) + if n <= 0: + return [] + chosen = sorted(rng.sample(candidate_ts, n)) + out: list[dict[str, Any]] = [] + for t in chosen: + t_snap = _snap_to_frame(t, record.frame_timestamps) + current_subtask = record.episode_task + prompt = load_prompt("module_2_interjection").format( + episode_task=record.episode_task, + current_subtask=current_subtask, + timestamp=t_snap, + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if not isinstance(result, dict): + continue + interjection_text = result.get("interjection") + speech_text = result.get("speech") + if not isinstance(interjection_text, str) or not interjection_text.strip(): + continue + if not isinstance(speech_text, str) or not speech_text.strip(): + continue + out.append( + { + "role": "user", + "content": interjection_text.strip(), + "style": "interjection", + "timestamp": t_snap, + "tool_calls": None, + } + ) + out.append(speech_atom(t_snap, speech_text.strip())) + return out diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py new file mode 100644 index 000000000..47c905e1d --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles).""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +from ..config import Module1Config +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord, keyframe_indices +from ..staging import EpisodeStaging +from ..vlm_client import VlmClient + + +def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float: + """Snap an arbitrary float to the nearest exact source frame timestamp.""" + if not frame_timestamps: + return float(t) + nearest = min(frame_timestamps, key=lambda f: abs(f - t)) + return float(nearest) + + +@dataclass +class PlanSubtasksMemoryModule: + """Generate subtask spans, plan, and memory rows. + + All output is persistent (lives in ``language_persistent``): + + - ``subtask`` rows: one per span, stamped at the span's *start* timestamp + (snapped to an exact frame). + - ``plan`` rows: emitted at ``t=0``; refreshed at every interjection + timestamp via :meth:`run_plan_updates` (called by the executor after + Module 2 completes). 
+ - ``memory`` rows: emitted at each subtask boundary (= subtask start + timestamp from the second subtask onward). + """ + + vlm: VlmClient + config: Module1Config + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + rows: list[dict[str, Any]] = [] + subtask_spans = self._generate_subtasks(record) + # subtask rows + for span in subtask_spans: + rows.append( + { + "role": "assistant", + "content": span["text"], + "style": "subtask", + "timestamp": _snap_to_frame(span["start"], record.frame_timestamps), + "tool_calls": None, + } + ) + # plan row at t=0 + plan_text = self._generate_plan(record, subtask_spans) + if plan_text is not None: + t0 = record.frame_timestamps[0] if record.frame_timestamps else 0.0 + rows.append( + { + "role": "assistant", + "content": plan_text, + "style": "plan", + "timestamp": float(t0), + "tool_calls": None, + } + ) + # memory rows at every subtask boundary except the very first start + prior_memory = "" + for i, span in enumerate(subtask_spans[1:], start=1): + completed = subtask_spans[i - 1]["text"] + remaining = [s["text"] for s in subtask_spans[i:]] + mem_text = self._generate_memory(record, prior_memory, completed, remaining) + if mem_text: + ts = _snap_to_frame(span["start"], record.frame_timestamps) + rows.append( + { + "role": "assistant", + "content": mem_text, + "style": "memory", + "timestamp": ts, + "tool_calls": None, + } + ) + prior_memory = mem_text + staging.write("module_1", rows) + + def run_plan_updates( + self, + record: EpisodeRecord, + staging: EpisodeStaging, + interjection_times: Sequence[float], + ) -> None: + """Append additional ``plan`` rows at every interjection timestamp.""" + existing = staging.read("module_1") + spans = self._reconstruct_subtasks_from_rows(existing) + new_rows = list(existing) + for raw_t in interjection_times: + t = _snap_to_frame(raw_t, record.frame_timestamps) + plan_text = self._generate_plan(record, spans, refresh_t=t) + if plan_text is not None: + new_rows.append( + { + "role": "assistant", + "content": plan_text, + "style": "plan", + "timestamp": t, + "tool_calls": None, + } + ) + staging.write("module_1", new_rows) + + @staticmethod + def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]: + out = [] + last_t: float | None = None + for row in sorted( + (r for r in rows if r.get("style") == "subtask"), + key=lambda r: float(r["timestamp"]), + ): + t = float(row["timestamp"]) + if last_t is not None: + out[-1]["end"] = t + out.append({"text": row.get("content") or "", "start": t, "end": t}) + last_t = t + return out + + def _generate_subtasks(self, record: EpisodeRecord) -> list[dict[str, Any]]: + if record.row_count == 0 or not record.frame_timestamps: + return [] + episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] + keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode) + prompt = load_prompt("module_1_subtasks").format( + episode_task=record.episode_task, + num_keyframes=len(keyframe_local), + min_subtask_seconds=self.config.min_subtask_seconds, + max_steps=self.config.plan_max_steps, + episode_duration=f"{episode_duration:.3f}", + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + spans = result.get("subtasks") if isinstance(result, dict) else None + if not spans: + return [] + # clamp to [t0, t_last] and sort + t0 = 
record.frame_timestamps[0] + t_last = record.frame_timestamps[-1] + cleaned: list[dict[str, Any]] = [] + for span in spans: + try: + start = float(span["start"]) + end = float(span["end"]) + text = str(span["text"]).strip() + except (KeyError, ValueError, TypeError): + continue + start = max(t0, min(start, t_last)) + end = max(t0, min(end, t_last)) + if end < start: + start, end = end, start + if not text: + continue + cleaned.append({"text": text, "start": start, "end": end}) + cleaned.sort(key=lambda s: s["start"]) + return cleaned + + def _generate_plan( + self, + record: EpisodeRecord, + subtask_spans: Sequence[dict[str, Any]], + *, + refresh_t: float | None = None, + ) -> str | None: + if not subtask_spans: + return None + subtasks_text = "\n".join(f"- {s['text']}" for s in subtask_spans) + prompt = load_prompt("module_1_plan").format( + episode_task=record.episode_task, + subtasks_text=subtasks_text, + plan_max_steps=self.config.plan_max_steps, + ) + if refresh_t is not None: + prompt += f"\n\n(This is a plan refresh after a user interjection at t={refresh_t:.2f}s.)\n" + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if isinstance(result, dict) and isinstance(result.get("plan"), str): + return result["plan"].strip() + return None + + def _generate_memory( + self, + record: EpisodeRecord, + prior_memory: str, + completed: str, + remaining: Sequence[str], + ) -> str: + prompt = load_prompt("module_1_memory").format( + episode_task=record.episode_task, + prior_memory=prior_memory or "(none)", + completed_subtask=completed, + remaining_subtasks=", ".join(remaining) if remaining else "(none)", + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if isinstance(result, dict) and isinstance(result.get("memory"), str): + return result["memory"].strip() + return "" diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py b/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py new file mode 100644 index 000000000..5ce8e163b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prompt templates loaded as plain text. + +One file per use site. Templates use ``str.format(**vars)`` substitution; we +intentionally avoid jinja2 here so the templates remain inspectable in +plain editors and roundtrip cleanly through ``ruff format``. 
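
Example (hypothetical values; the real callers live in :mod:`..modules`)::

    text = load("module_1_plan").format(
        episode_task="set the table",
        subtasks_text="- pick up the plate",
        plan_max_steps=8,
    )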
+""" + +from __future__ import annotations + +from pathlib import Path + +_DIR = Path(__file__).parent + + +def load(name: str) -> str: + """Read prompt template ``name.txt`` from the ``prompts/`` directory.""" + path = _DIR / f"{name}.txt" + return path.read_text(encoding="utf-8") diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_1_memory.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_memory.txt new file mode 100644 index 000000000..6a89ecefa --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_memory.txt @@ -0,0 +1,25 @@ +You are updating the robot's compressed semantic memory at the boundary of +a completed subtask. + +Reference (verbatim from MEM, Torne 2026): +"Remove or compress information in the language memory whenever +appropriate. Keep ONLY the minimal set of relevant information for future +task execution. Specific object attributes (colors, precise quantities of +each item) get discarded when their details won't affect subsequent +actions. Functional outcomes (where items went, how many) are preserved." + +Concrete example from MEM: + Before: "I put a light green bowl, a dark blue bowl and a bright yellow + bowl into the top right cabinet" + After: "I placed three bowls in the top right cabinet" + +Episode task: "{episode_task}" +Previous memory: {prior_memory} +Just-completed subtask: "{completed_subtask}" +Remaining subtasks (for relevance judgement only): {remaining_subtasks} + +Update the memory. Drop irrelevant detail. Compress completed steps. +Keep WHAT happened, drop HOW. Shorter is better. + +Output strictly valid JSON: + {{ "memory": "" }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_1_plan.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_plan.txt new file mode 100644 index 000000000..b0121c977 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_plan.txt @@ -0,0 +1,18 @@ +You are the high-level planner for a robot demonstrating: "{episode_task}". + +Given the subtask decomposition below, write a concise hierarchical PLAN +the robot should follow. Format the plan as a numbered list, one line per +high-level step. The plan describes the full task; subtasks are the atomic +skills used to execute it. + +Subtasks for context: +{subtasks_text} + +Authoring rules: +- 3 to {plan_max_steps} steps. +- Each step describes one logical chunk of the task, not one motion. +- Steps must be in execution order. +- Plain prose, no JSON, no markdown headers. + +Output strictly valid JSON: + {{ "plan": "1. ...\n2. ...\n3. ..." }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_1_subtasks.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_subtasks.txt new file mode 100644 index 000000000..523312123 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_1_subtasks.txt @@ -0,0 +1,31 @@ +You are labeling a teleoperated robot demonstration. + +The user originally asked: "{episode_task}" + +You will be shown {num_keyframes} keyframes spaced evenly across the +episode. Decompose the demonstration into a list of consecutive atomic +subtasks the robot performs. + +Authoring rules — based on Hi Robot (Shi 2025) atom granularity and Pi0.7 +(Physical Intelligence 2025) "how, not what" detail: + +- Each subtask is one atomic skill the low-level policy can execute, e.g. + "pick up one piece of lettuce", "place the bowl into the box", + "move the right arm to the left". 
- Capture HOW the subtask is performed, not only WHAT — e.g. prefer
  "grasp the handle of the sponge with the left hand" to "pick up the
  sponge".
- Subtasks are non-overlapping and must cover the full episode in order.
- Each subtask spans at least {min_subtask_seconds} seconds.
- Do not exceed {max_steps} subtasks total.
- Every subtask's [start_time, end_time] must lie within
  [0.0, {episode_duration}] seconds.

Output strictly valid JSON of shape:

  {{
    "subtasks": [
      {{"text": "<subtask description>", "start": <float seconds>, "end": <float seconds>}},
      ...
    ]
  }}
diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_2_initial_speech.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_initial_speech.txt
new file mode 100644
index 000000000..6058b1f5c
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_initial_speech.txt
@@ -0,0 +1,10 @@
The user just asked the robot: "{episode_task}".

Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: confident, friendly, single short sentence.

Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".

Output strictly valid JSON:
  {{ "text": "<acknowledgement sentence>" }}
diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt
new file mode 100644
index 000000000..0ecb78f9d
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_2_interjection.txt
@@ -0,0 +1,27 @@
You are simulating a user mid-episode interruption for a robot doing:
"{episode_task}".

Synthesize ONE realistic interruption the user might say at this moment in
the demonstration, plus the robot's verbal acknowledgement.

Context (Hi Robot, Shi 2025) — interjections fall into one of these
scenario types:
- negative task: "actually skip X"
- situated correction: "that's not trash"
- specific constraint: "use less salt"
- preference: "could you also do Y"

Interruption rules:
- Must be plausible given the current subtask context.
- Must change the plan in a non-trivial way (a new constraint, skipped
  step, or correction).
- One sentence each.

Current subtask context: {current_subtask}
Time into episode: {timestamp:.2f}s

Output strictly valid JSON:
  {{
    "interjection": "<user interruption>",
    "speech": "<robot acknowledgement>"
  }}
diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/module_3_vqa.txt b/src/lerobot/annotations/steerable_pipeline/prompts/module_3_vqa.txt
new file mode 100644
index 000000000..23590b381
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/prompts/module_3_vqa.txt
@@ -0,0 +1,32 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.

The frame shows a robot working on: "{episode_task}".

Question types and the EXACT answer JSON shape required for each:

  bbox      => {{"detections": [{{"label": "<object name>", "bbox_format": "xyxy",
                 "bbox": [x1, y1, x2, y2]}}, ...]}}
               bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
               ECoT example: "a white cup [124, 25, 176, 113]".
  keypoint  => {{"label": "<object name>", "point_format": "xy",
                 "point": [x, y]}}

  count     => {{"label": "<object name>", "count": <integer>,
                 "note": "<optional clarification>"}}

  attribute => {{"label": "<object name>", "attribute": "<attribute name>",
                 "value": "<attribute value>"}}

  spatial   => {{"subject": "<object name>", "relation": "<spatial relation>",
                 "object": "<object name>"}}

Generate a question of type "{question_type}". Output strictly valid JSON:

  {{
    "question": "<question text>",
    "answer": <answer JSON matching the question type>
  }}
diff --git a/src/lerobot/annotations/steerable_pipeline/reader.py b/src/lerobot/annotations/steerable_pipeline/reader.py
new file mode 100644
index 000000000..9e0b20fab
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/reader.py
@@ -0,0 +1,219 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.

The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:

- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)

This shape lets each module operate per-episode without loading all parquet
rows into memory at once. It deliberately does not depend on datatrove —
datatrove integration wraps this generator inside a ``PipelineStep`` in
:mod:`.executor`.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pyarrow.parquet as pq

from lerobot.datasets.utils import DEFAULT_TASKS_PATH


@dataclass
class EpisodeRecord:
    """Per-episode record yielded by the reader."""

    episode_index: int
    episode_task: str
    frame_timestamps: tuple[float, ...]
    frame_indices: tuple[int, ...]
    data_path: Path
    row_offset: int  # row offset within the parquet file where this episode starts
    row_count: int  # number of rows for this episode

    def frames_df(self):  # type: ignore[no-untyped-def]
        """Lazy-load the pandas slice for this episode."""
        import pandas as pd  # noqa: PLC0415 - deferred for optional dataset extra

        table = pq.read_table(self.data_path)
        df: pd.DataFrame = table.to_pandas()
        slice_ = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(drop=True)
        return slice_


def _load_tasks_lookup(root: Path) -> dict[int, str]:
    tasks_path = root / DEFAULT_TASKS_PATH
    if not tasks_path.exists():
        return {}
    table = pq.read_table(tasks_path)
    cols = {name: table.column(name).to_pylist() for name in table.column_names}
    if "task_index" in cols and "task" in cols:
        return dict(zip(cols["task_index"], cols["task"], strict=True))
    raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")


def iter_episodes(root: Path, *, only_episodes: tuple[int, ...]
| None = None) -> Iterator[EpisodeRecord]: + """Yield :class:`EpisodeRecord` for every episode under ``root/data/``. + + Episodes are yielded in ascending ``episode_index`` order. The reader does + not assume a specific chunk/file layout: it scans every ``*.parquet`` + under ``data/`` and groups by ``episode_index``. + """ + tasks = _load_tasks_lookup(root) + data_dir = root / "data" + parquet_files = sorted(data_dir.rglob("*.parquet")) + + only_set = set(only_episodes) if only_episodes is not None else None + + for path in parquet_files: + yield from _iter_one_path(path, tasks, only_set) + + +def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]: + table = pq.read_table(path) + names = table.column_names + if "episode_index" not in names: + return + episode_col = table.column("episode_index").to_pylist() + timestamp_col = ( + table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col) + ) + frame_col = ( + table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col))) + ) + task_col = table.column("task_index").to_pylist() if "task_index" in names else None + + def _build( + ep: int, + start: int, + end: int, + task_idx: int | None, + ts_buf: list[float], + fi_buf: list[int], + ) -> EpisodeRecord | None: + if only_set is not None and ep not in only_set: + return None + task = tasks.get(task_idx, "") if task_idx is not None else "" + return EpisodeRecord( + episode_index=ep, + episode_task=task, + frame_timestamps=tuple(ts_buf), + frame_indices=tuple(fi_buf), + data_path=path, + row_offset=start, + row_count=end - start, + ) + + cur_ep: int | None = None + start_offset = 0 + ts_buf: list[float] = [] + fi_buf: list[int] = [] + cur_task_idx: int | None = None + + for i, ep in enumerate(episode_col): + if cur_ep is None: + cur_ep = ep + start_offset = i + ts_buf = [timestamp_col[i]] + fi_buf = [frame_col[i]] + cur_task_idx = task_col[i] if task_col is not None else None + continue + if ep != cur_ep: + rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf) + if rec is not None: + yield rec + cur_ep = ep + start_offset = i + ts_buf = [timestamp_col[i]] + fi_buf = [frame_col[i]] + cur_task_idx = task_col[i] if task_col is not None else None + else: + ts_buf.append(timestamp_col[i]) + fi_buf.append(frame_col[i]) + + if cur_ep is not None: + rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf) + if rec is not None: + yield rec + + +def gather_data_paths(root: Path) -> list[Path]: + """Return every ``data/chunk-*/file-*.parquet`` path under ``root``.""" + return sorted((root / "data").rglob("*.parquet")) + + +def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]: + """Return ``{episode_index: (row_offset, row_count)}`` for one parquet.""" + table = pq.read_table(path, columns=["episode_index"]) + episode_col = table.column("episode_index").to_pylist() + out: dict[int, tuple[int, int]] = {} + cur_ep: int | None = None + start = 0 + for i, ep in enumerate(episode_col): + if cur_ep is None: + cur_ep = ep + start = i + continue + if ep != cur_ep: + out[cur_ep] = (start, i - start) + cur_ep = ep + start = i + if cur_ep is not None: + out[cur_ep] = (start, len(episode_col) - start) + return out + + +def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]: + """Return ``k`` evenly spaced row indices into the episode (relative).""" + n = record.row_count + if k <= 0 or n == 0: + return [] + if k >= n: + return 
list(range(n))
    step = (n - 1) / (k - 1) if k > 1 else 0.0
    return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]


def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
    """Find the parquet file containing ``episode_index`` and its slice bounds."""
    for path in gather_data_paths(root):
        offsets = episode_offsets_per_path(path)
        if episode_index in offsets:
            start, count = offsets[episode_index]
            return path, start, count
    return None


def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
    """Return the parquet path and per-frame timestamps for ``episode_index``."""
    found = lookup_data_path(root, episode_index)
    if found is None:
        raise ValueError(f"Episode {episode_index} not found under {root}/data/")
    path, start, count = found
    table = pq.read_table(path, columns=["timestamp"])
    timestamps = table.column("timestamp").to_pylist()[start : start + count]
    return path, [float(t) for t in timestamps]
diff --git a/src/lerobot/annotations/steerable_pipeline/staging.py b/src/lerobot/annotations/steerable_pipeline/staging.py
new file mode 100644
index 000000000..02b957340
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/staging.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.

Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back
this staging tree and partitions rows into the two language columns.

JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""

from __future__ import annotations

import json
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

ModuleName = str

_MODULES: tuple[ModuleName, ...]
= ( + "module_1", + "module_2", + "module_3", +) + + +@dataclass +class EpisodeStaging: + """Filesystem layout for a single episode's staged module outputs.""" + + root: Path + episode_index: int + + @property + def episode_dir(self) -> Path: + return self.root / f"episode_{self.episode_index:06d}" + + def path_for(self, module: ModuleName) -> Path: + if module not in _MODULES: + raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}") + return self.episode_dir / f"{module}.jsonl" + + def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path: + path = self.path_for(module) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=False, sort_keys=True)) + f.write("\n") + return path + + def read(self, module: ModuleName) -> list[dict[str, Any]]: + path = self.path_for(module) + if not path.exists(): + return [] + out: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + out.append(json.loads(line)) + return out + + def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]: + return {m: self.read(m) for m in _MODULES} + + def has(self, module: ModuleName) -> bool: + return self.path_for(module).exists() + + +def iter_staged_episodes(root: Path) -> Iterator[int]: + """Yield episode indices for which any staging artifact exists.""" + if not root.exists(): + return + for child in sorted(root.iterdir()): + if child.is_dir() and child.name.startswith("episode_"): + try: + yield int(child.name.removeprefix("episode_")) + except ValueError: + continue diff --git a/src/lerobot/annotations/steerable_pipeline/validator.py b/src/lerobot/annotations/steerable_pipeline/validator.py new file mode 100644 index 000000000..ccc79bc38 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/validator.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pre-write validation against staged outputs. + +Runs after Modules 1–3 have all written their per-episode artifacts but +*before* the writer rewrites parquet shards. The validator never touches +parquet; it only inspects the staging tree and the source frame timestamps +exposed by :class:`EpisodeRecord`. 
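+
+A minimal usage sketch (illustrative only; the executor performs the real
+call, with ``records`` coming from the reader's ``iter_episodes`` and
+``staging_dir`` pointing at the staging root)::
+
+    report = StagingValidator().validate(records, staging_dir)
+    if not report.ok:
+        raise RuntimeError(report.summary())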
+ +Checks (per the plan's "Intermediate staging and validation" section): + +- exact timestamp alignment against source frame timestamps +- no orphan speech / interjection pairs +- plan / memory emission consistency (events have a paired persistent row) +- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count / + attribute / spatial) +- every row maps to its correct column under :func:`column_for_style` +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterable, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from lerobot.datasets.language import ( + LANGUAGE_EVENTS, + LANGUAGE_PERSISTENT, + column_for_style, +) + +from .reader import EpisodeRecord +from .staging import EpisodeStaging + +logger = logging.getLogger(__name__) + + +@dataclass +class ValidationReport: + """Outcome of one validation pass across all episodes.""" + + errors: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + episodes_checked: int = 0 + + @property + def ok(self) -> bool: + return not self.errors + + def add_error(self, message: str) -> None: + self.errors.append(message) + + def add_warning(self, message: str) -> None: + self.warnings.append(message) + + def summary(self) -> str: + return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}" + + +VQA_ANSWER_SHAPES: dict[str, set[str]] = { + "bbox": {"detections"}, + "keypoint": {"label", "point_format", "point"}, + "count": {"label", "count"}, + "attribute": {"label", "attribute", "value"}, + "spatial": {"subject", "relation", "object"}, +} + + +def classify_vqa_answer(payload: Any) -> str | None: + """Best-effort classification of a VQA answer payload to a question type.""" + if not isinstance(payload, dict): + return None + keys = set(payload.keys()) + for kind, required in VQA_ANSWER_SHAPES.items(): + if required.issubset(keys): + return kind + return None + + +@dataclass +class StagingValidator: + """Walks the staging tree and produces a :class:`ValidationReport`.""" + + timestamp_atol: float = 0.0 # exact-match by default + + def validate( + self, + records: Sequence[EpisodeRecord], + staging_dir: Path, + ) -> ValidationReport: + report = ValidationReport() + for record in records: + self._validate_episode(record, staging_dir, report) + report.episodes_checked += 1 + return report + + def _validate_episode( + self, + record: EpisodeRecord, + staging_dir: Path, + report: ValidationReport, + ) -> None: + staging = EpisodeStaging(staging_dir, record.episode_index) + staged = staging.read_all() + all_rows: list[dict[str, Any]] = [] + for module_name, rows in staged.items(): + for row in rows: + row = {**row, "_module": module_name} + all_rows.append(row) + + frame_ts = set(record.frame_timestamps) + + events: list[dict[str, Any]] = [] + persistent: list[dict[str, Any]] = [] + for row in all_rows: + self._check_column_routing(row, report, record.episode_index) + if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT: + persistent.append(row) + else: + events.append(row) + + for row in events: + self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index) + + self._check_speech_interjection_pairs(events, report, record.episode_index) + self._check_plan_memory_consistency(persistent, events, report, record.episode_index) + self._check_vqa_json(events, report, record.episode_index) + + def _check_column_routing( + self, + row: 
dict[str, Any], + report: ValidationReport, + episode_index: int, + ) -> None: + style = row.get("style") + module = row.get("_module") + try: + target_col = column_for_style(style) + except ValueError: + report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}") + return + if module == "module_1" and target_col != LANGUAGE_PERSISTENT: + report.add_error( + f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)" + ) + if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS: + report.add_error( + f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)" + ) + + def _check_event_timestamp_alignment( + self, + row: dict[str, Any], + frame_ts: set[float], + report: ValidationReport, + episode_index: int, + ) -> None: + ts = row.get("timestamp") + if ts is None: + report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}") + return + if self.timestamp_atol == 0.0: + if float(ts) not in frame_ts: + report.add_error( + f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp" + ) + else: + if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts): + report.add_error( + f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame" + ) + + def _check_speech_interjection_pairs( + self, + events: Iterable[dict[str, Any]], + report: ValidationReport, + episode_index: int, + ) -> None: + speech_ts: dict[float, int] = {} + interjection_ts: dict[float, int] = {} + for row in events: + ts = row.get("timestamp") + if ts is None: + continue + ts_f = float(ts) + if row.get("style") is None and row.get("role") == "assistant": + speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1 + if row.get("style") == "interjection": + interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1 + + for ts in interjection_ts: + if ts not in speech_ts: + report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom") + + def _check_plan_memory_consistency( + self, + persistent: Sequence[dict[str, Any]], + events: Sequence[dict[str, Any]], + report: ValidationReport, + episode_index: int, + ) -> None: + plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"}) + memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"}) + subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"}) + interjection_ts = sorted( + { + float(r["timestamp"]) + for r in events + if r.get("style") == "interjection" and r.get("timestamp") is not None + } + ) + + if persistent and not plan_ts: + report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted") + # every interjection should have a same-timestamp plan refresh + for ts in interjection_ts: + if ts not in set(plan_ts): + report.add_error( + f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update" + ) + # memory should be emitted at subtask boundaries (subset relation) + if memory_ts and subtask_ts: + mem_set = set(memory_ts) + sub_set = set(subtask_ts) + stray = sorted(mem_set - sub_set) + if stray: + report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary") + + def _check_vqa_json( + self, + events: Iterable[dict[str, Any]], + report: ValidationReport, + episode_index: int, + ) -> None: + for row in events: + if row.get("style") != "vqa" or row.get("role") 
!= "assistant": + continue + content = row.get("content") + if content is None: + report.add_error( + f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content" + ) + continue + try: + payload = json.loads(content) + except (TypeError, ValueError) as exc: + report.add_error( + f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}" + ) + continue + shape = classify_vqa_answer(payload) + if shape is None: + report.add_error( + f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}" + ) diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py new file mode 100644 index 000000000..923445f35 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared Qwen-VL client. + +The pipeline uses a single shared VLM across modules. vLLM is preferred when +available (high throughput, JSON-guided decoding); transformers is the +fallback. A ``stub`` backend is used for unit tests so fixtures never call +into a real model. + +The client speaks one method, :meth:`VlmClient.generate_json`, which: + +- accepts a list of OpenAI/HF-style multimodal messages, +- requests JSON output (``json_mode=True`` enables guided decoding when the + backend supports it), +- batches requests transparently, +- and reprompts once on a JSON parse failure with an inline correction + message before raising. +""" + +from __future__ import annotations + +import json +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any, Protocol + +from .config import VlmConfig + + +class VlmClient(Protocol): + """Protocol every backend must implement.""" + + def generate_json( + self, + messages_batch: Sequence[Sequence[dict[str, Any]]], + *, + max_new_tokens: int | None = None, + temperature: float | None = None, + ) -> list[Any]: + """Generate one JSON-decoded response per messages list.""" + + +@dataclass +class StubVlmClient: + """Deterministic stub used in unit tests. + + A test passes a callable that maps the *last user message text* (or, if + that is empty, the full message list) to a JSON-serializable response. + """ + + responder: Callable[[Sequence[dict[str, Any]]], Any] + + def generate_json( + self, + messages_batch: Sequence[Sequence[dict[str, Any]]], + *, + max_new_tokens: int | None = None, + temperature: float | None = None, + ) -> list[Any]: + return [self.responder(list(messages)) for messages in messages_batch] + + +def _strip_to_json(text: str) -> Any: + text = text.strip() + if text.startswith("```"): + # tolerate ```json ... 
``` fences from chat-tuned backbones + first = text.find("\n") + last = text.rfind("```") + if first != -1 and last != -1 and last > first: + text = text[first + 1 : last].strip() + return json.loads(text) + + +@dataclass +class _GenericTextClient: + """Wraps any text-generation callable in JSON-mode + one-retry semantics.""" + + generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]] + config: VlmConfig + + def generate_json( + self, + messages_batch: Sequence[Sequence[dict[str, Any]]], + *, + max_new_tokens: int | None = None, + temperature: float | None = None, + ) -> list[Any]: + max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens + temp = temperature if temperature is not None else self.config.temperature + raw = self.generate_text(messages_batch, max_tok, temp) + out: list[Any] = [] + for messages, text in zip(messages_batch, raw, strict=True): + try: + out.append(_strip_to_json(text)) + continue + except (ValueError, json.JSONDecodeError): + pass + retry = list(messages) + [ + {"role": "assistant", "content": text}, + { + "role": "user", + "content": ( + "Your previous reply was not valid JSON. " + "Reply with strictly valid JSON, no prose, no fences." + ), + }, + ] + retry_text = self.generate_text([retry], max_tok, temp)[0] + out.append(_strip_to_json(retry_text)) + return out + + +def make_vlm_client(config: VlmConfig) -> VlmClient: + """Build the shared VLM client per the configured backend. + + For ``stub``, callers should construct :class:`StubVlmClient` directly with + a responder callable. ``stub`` here is rejected to make accidental misuse + obvious. + """ + if config.backend == "stub": + raise ValueError( + "Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients." + ) + if config.backend == "vllm": + return _make_vllm_client(config) + if config.backend == "transformers": + return _make_transformers_client(config) + raise ValueError(f"Unknown VLM backend: {config.backend!r}") + + +def _make_vllm_client(config: VlmConfig) -> VlmClient: + try: + from vllm import LLM, SamplingParams # type: ignore[import-not-found] + except ImportError as exc: + raise ImportError( + "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`." 
+ ) from exc + llm = LLM(model=config.model_id) + + def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]: + params = SamplingParams( + max_tokens=max_tok, + temperature=temp, + guided_decoding={"json": {}} if config.json_mode else None, + ) + prompts = [_messages_to_prompt(m) for m in batch] + outputs = llm.generate(prompts, params) + return [o.outputs[0].text for o in outputs] + + return _GenericTextClient(_gen, config) + + +def _make_transformers_client(config: VlmConfig) -> VlmClient: + try: + import torch # type: ignore[import-not-found] + from transformers import AutoModelForVision2Seq, AutoProcessor # type: ignore[import-not-found] + except ImportError as exc: + raise ImportError("transformers + torch are required for backend='transformers'.") from exc + processor = AutoProcessor.from_pretrained(config.model_id) + model = AutoModelForVision2Seq.from_pretrained(config.model_id, torch_dtype="auto") + model.eval() + + def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]: + outs: list[str] = [] + for messages in batch: + text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + inputs = processor(text=[text], return_tensors="pt").to(model.device) + with torch.no_grad(): + gen = model.generate( + **inputs, + max_new_tokens=max_tok, + temperature=temp, + do_sample=temp > 0.0, + ) + decoded = processor.batch_decode( + gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True + )[0] + outs.append(decoded) + return outs + + return _GenericTextClient(_gen, config) + + +def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any: + """Pass-through hook used by the vllm backend. + + vllm exposes its own multimodal entry points that vary by version; for the + base flow we simply forward the raw message list and let the caller's + custom backend handle templating. Real deployments override this. + """ + return list(messages) diff --git a/src/lerobot/annotations/steerable_pipeline/writer.py b/src/lerobot/annotations/steerable_pipeline/writer.py new file mode 100644 index 000000000..c83a2b168 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/writer.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Final parquet rewrite. + +For every episode the writer: + +1. reads the staged module outputs, +2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event + slice (EVENT_ONLY_STYLES + style=None tool-call atoms), +3. sorts each slice deterministically, +4. broadcasts the persistent slice across every frame in the episode, +5. for each frame, materializes the sublist of event rows whose timestamp + exactly equals that frame's timestamp, +6. drops the legacy ``subtask_index`` column, +7. adds a top-level ``tools`` column containing the JSON schema for ``say``, +8. writes the parquet shard back in place. 
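+
+As a concrete sketch (hypothetical numbers): given frame timestamps
+``[0.0, 0.1, 0.2]`` and one staged interjection at ``t=0.1``, all three
+frames carry the same ``language_persistent`` list, while
+``language_events`` is empty on the ``0.0`` and ``0.2`` frames and holds
+the interjection row (plus its paired speech atom) on the ``0.1`` frame.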
+ +Invariants enforced here (and re-checked by the validator): + +- per-episode persistent slice is byte-identical across every frame; +- ``language_events`` rows on a frame all have ``timestamp == frame_ts`` + (timestamps come straight from the source parquet — never recomputed); +- every row passes ``column_for_style(style)``. +""" + +from __future__ import annotations + +import json +import logging +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pyarrow as pa +import pyarrow.parquet as pq + +from lerobot.datasets.language import ( + EVENT_ONLY_STYLES, + LANGUAGE_EVENTS, + LANGUAGE_PERSISTENT, + PERSISTENT_STYLES, + column_for_style, +) + +from .reader import EpisodeRecord +from .staging import EpisodeStaging + +logger = logging.getLogger(__name__) + + +SAY_TOOL_SCHEMA: dict[str, Any] = { + "type": "function", + "function": { + "name": "say", + "description": "Speak a short utterance to the user via the TTS executor.", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The verbatim text to speak.", + } + }, + "required": ["text"], + }, + }, +} + + +def _row_persistent_sort_key(row: dict[str, Any]) -> tuple: + return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "") + + +def _row_event_sort_key(row: dict[str, Any]) -> tuple: + # events are bucketed per-frame, but within a frame we still want determinism + return (row.get("style") or "", row.get("role") or "") + + +def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]: + """Coerce a staged row into the persistent column's struct shape.""" + style = row.get("style") + if style not in PERSISTENT_STYLES: + raise ValueError( + f"persistent slice contains row with non-persistent style {style!r}; " + "row would be misrouted under column_for_style()" + ) + if "timestamp" not in row: + raise ValueError(f"persistent row missing timestamp: {row!r}") + return { + "role": str(row["role"]), + "content": None if row.get("content") is None else str(row["content"]), + "style": style, + "timestamp": float(row["timestamp"]), + "tool_calls": _normalize_tool_calls(row.get("tool_calls")), + } + + +def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]: + """Coerce a staged row into the event column's struct shape (no timestamp).""" + style = row.get("style") + if style is not None and style not in EVENT_ONLY_STYLES: + raise ValueError( + f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}" + ) + if column_for_style(style) != LANGUAGE_EVENTS: + raise ValueError(f"event row with style {style!r} would not route to language_events") + return { + "role": str(row["role"]), + "content": None if row.get("content") is None else str(row["content"]), + "style": style, + "tool_calls": _normalize_tool_calls(row.get("tool_calls")), + } + + +def _normalize_tool_calls(value: Any) -> list[Any] | None: + if value is None: + return None + if not isinstance(value, list): + raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}") + return list(value) + + +def _validate_atom_invariants(row: dict[str, Any]) -> None: + """At-least-one of content/tool_calls; style=None implies tool_calls.""" + has_content = row.get("content") is not None + has_tools = row.get("tool_calls") is not None + if not (has_content or has_tools): + raise ValueError(f"row has neither content nor tool_calls: 
{row!r}") + if row.get("style") is None and not has_tools: + raise ValueError(f"style=None requires tool_calls: {row!r}") + + +def _validate_speech_atom(row: dict[str, Any]) -> None: + """Speech atoms: role=assistant, style=None, content=None, say tool call.""" + if row.get("style") is not None: + return # not a speech atom + if row.get("role") != "assistant": + raise ValueError(f"speech atom must have role=assistant: {row!r}") + if row.get("content") is not None: + raise ValueError(f"speech atom must have content=null: {row!r}") + tool_calls = row.get("tool_calls") + if not tool_calls or not isinstance(tool_calls, list): + raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}") + first = tool_calls[0] + if not isinstance(first, dict): + raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}") + if first.get("type") != "function": + raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}") + fn = first.get("function") or {} + if fn.get("name") != "say": + raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}") + args = fn.get("arguments") or {} + if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str): + raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}") + + +@dataclass +class LanguageColumnsWriter: + """Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns.""" + + drop_existing_subtask_index: bool = True + + def write_all( + self, + records: Sequence[EpisodeRecord], + staging_dir: Path, + root: Path, + ) -> list[Path]: + episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list) + for record in records: + episodes_by_path[record.data_path].append(record) + + written: list[Path] = [] + for path, eps in episodes_by_path.items(): + self._rewrite_one(path, eps, staging_dir, root) + written.append(path) + return written + + def _rewrite_one( + self, + path: Path, + episodes: Sequence[EpisodeRecord], + staging_dir: Path, + root: Path, + ) -> None: + table = pq.read_table(path) + n_rows = table.num_rows + + # Ensure we cover every episode in the file. Episodes that don't have + # staging artifacts are passed through with empty annotation lists — + # this keeps the writer idempotent and safe for partial reruns. 
+ staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {} + for record in episodes: + staging = EpisodeStaging(staging_dir, record.episode_index) + staged_per_ep[record.episode_index] = staging.read_all() + + persistent_by_ep: dict[int, list[dict[str, Any]]] = {} + events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {} + + for ep_index, ep_staged in staged_per_ep.items(): + persistent_rows: list[dict[str, Any]] = [] + event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed + for _module_name, rows in ep_staged.items(): + for row in rows: + style = row.get("style") + if column_for_style(style) == LANGUAGE_PERSISTENT: + persistent_rows.append(row) + else: + event_rows.append(row) + + persistent_rows.sort(key=_row_persistent_sort_key) + normalized_persistent = [] + for r in persistent_rows: + _validate_atom_invariants(r) + _validate_speech_atom(r) + normalized_persistent.append(_normalize_persistent_row(r)) + persistent_by_ep[ep_index] = normalized_persistent + + buckets: dict[float, list[dict[str, Any]]] = defaultdict(list) + for r in event_rows: + _validate_atom_invariants(r) + _validate_speech_atom(r) + ts = float(r["timestamp"]) + buckets[ts].append(_normalize_event_row(r)) + for ts in list(buckets.keys()): + buckets[ts].sort(key=_row_event_sort_key) + events_by_ep_ts[ep_index] = buckets + + episode_col = ( + table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None + ) + ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None + if episode_col is None or ts_col is None: + raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.") + + per_row_persistent: list[list[dict[str, Any]]] = [] + per_row_events: list[list[dict[str, Any]]] = [] + for i in range(n_rows): + ep = episode_col[i] + ts = float(ts_col[i]) + per_row_persistent.append(persistent_by_ep.get(ep, [])) + buckets = events_by_ep_ts.get(ep, {}) + per_row_events.append(buckets.get(ts, [])) + + new_table = self._materialize_table( + table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index + ) + pq.write_table(new_table, path) + + def _materialize_table( + self, + table: pa.Table, + persistent: list[list[dict[str, Any]]], + events: list[list[dict[str, Any]]], + *, + drop_old: bool, + ) -> pa.Table: + cols = [] + names = [] + for name in table.column_names: + if drop_old and name == "subtask_index": + continue + if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"): + continue # we'll re-add canonical versions + cols.append(table.column(name)) + names.append(name) + + # We let pyarrow infer struct/list schema rather than passing the + # canonical type from `lerobot.datasets.language` directly: that type + # uses `pa.json_()` for the `tool_calls` element type, which + # `pa.array(..., type=...)` cannot materialize from Python lists on + # current pyarrow versions. The inferred schema round-trips through + # parquet and `LeRobotDataset` correctly — see PR 1's + # `tests/datasets/test_language.py` which exercises the same flow. + persistent_arr = pa.array(persistent) + events_arr = pa.array(events) + + cols.extend([persistent_arr, events_arr]) + names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS]) + + # Dataset-level tools column. Store the JSON schema as a string per + # row (broadcast-identical, parquet dictionary-encodes it) — string + # storage avoids requiring pa.json_() on every consumer. 
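+        # Consumers can recover the schema with, for example,
+        # json.loads(table.column("tools").to_pylist()[0]); the recipe-render
+        # test exercises exactly this round-trip.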
+ tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True) + tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string()) + cols.append(tools_arr) + names.append("tools") + + return pa.Table.from_arrays(cols, names=names) + + +def speech_atom(timestamp: float, text: str) -> dict[str, Any]: + """Build a canonical speech tool-call atom for the events column.""" + return { + "role": "assistant", + "content": None, + "style": None, + "timestamp": float(timestamp), + "tool_calls": [ + { + "type": "function", + "function": { + "name": "say", + "arguments": {"text": text}, + }, + } + ], + } + + +def normalize_rows_for_writer( + rows: Iterable[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Helper used by tests/validators to partition a flat row list into + (persistent_rows, event_rows) using ``column_for_style``. + """ + persistent: list[dict[str, Any]] = [] + events: list[dict[str, Any]] = [] + for row in rows: + if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT: + persistent.append(row) + else: + events.append(row) + return persistent, events diff --git a/src/lerobot/scripts/lerobot_annotate.py b/src/lerobot/scripts/lerobot_annotate.py new file mode 100644 index 000000000..b71d3b3ba --- /dev/null +++ b/src/lerobot/scripts/lerobot_annotate.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``lerobot-annotate`` — populate ``language_persistent`` and +``language_events`` columns on a LeRobot dataset. + +Annotations live directly in ``data/chunk-*/file-*.parquet``: there is no +flavor namespace and no sidecar tree. Multiple revisions of the same dataset +mean multiple dataset copies. 
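+
+Pass ``--repo_id`` instead of ``--root`` to annotate a local
+``snapshot_download`` copy of a Hub dataset (see ``_resolve_root`` below).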
+ +Example: + + uv run lerobot-annotate \\ + --root=/path/to/dataset \\ + --vlm.backend=transformers \\ + --vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct +""" + +import logging +from pathlib import Path + +from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig +from lerobot.annotations.steerable_pipeline.executor import Executor +from lerobot.annotations.steerable_pipeline.modules import ( + GeneralVqaModule, + InterjectionsAndSpeechModule, + PlanSubtasksMemoryModule, +) +from lerobot.annotations.steerable_pipeline.validator import StagingValidator +from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client +from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter +from lerobot.configs import parser + +logger = logging.getLogger(__name__) + + +def _resolve_root(cfg: AnnotationPipelineConfig) -> Path: + if cfg.root is not None: + return Path(cfg.root) + if cfg.repo_id is not None: + from huggingface_hub import snapshot_download + + return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset")) + raise ValueError("Either --root or --repo_id must be provided.") + + +@parser.wrap() +def annotate(cfg: AnnotationPipelineConfig) -> None: + """Run the steerable annotation pipeline against a dataset.""" + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + root = _resolve_root(cfg) + logger.info("annotate: root=%s", root) + + vlm = make_vlm_client(cfg.vlm) + module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1) + module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed) + module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed) + writer = LanguageColumnsWriter() + validator = StagingValidator() + + executor = Executor( + config=cfg, + module_1=module_1, + module_2=module_2, + module_3=module_3, + writer=writer, + validator=validator, + ) + summary = executor.run(root) + logger.info("annotate: wrote %d shard(s)", len(summary.written_paths)) + for phase in summary.phases: + logger.info( + "annotate: phase=%s processed=%d skipped=%d", + phase.name, + phase.episodes_processed, + phase.episodes_skipped, + ) + if summary.validation_report.warnings: + for w in summary.validation_report.warnings: + logger.warning(w) + + +def main() -> None: + annotate() + + +if __name__ == "__main__": + main() diff --git a/tests/annotations/__init__.py b/tests/annotations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/annotations/_helpers.py b/tests/annotations/_helpers.py new file mode 100644 index 000000000..6a6290a1d --- /dev/null +++ b/tests/annotations/_helpers.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Helpers shared across annotation-pipeline tests.""" + +from __future__ import annotations + +import json +from typing import Any + +from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient + + +def make_canned_responder( + responses_by_marker: dict[str, Any], + default: Any = None, +) -> StubVlmClient: + """Return a stub that picks a response by inspecting the user prompt. + + For each call the responder examines the last user-message text and + returns the response keyed by the first marker substring it contains. + Falls back to ``default`` if no marker matches. + """ + + def responder(messages: list[dict[str, Any]]) -> Any: + last_user_text = "" + for message in messages: + if message.get("role") != "user": + continue + content = message.get("content") + if isinstance(content, str): + last_user_text = content + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + last_user_text = block.get("text", "") + for marker, response in responses_by_marker.items(): + if marker in last_user_text: + return response + return default + + return StubVlmClient(responder=responder) + + +def encode_vqa_answer(payload: dict[str, Any]) -> str: + return json.dumps(payload, sort_keys=True) diff --git a/tests/annotations/conftest.py b/tests/annotations/conftest.py new file mode 100644 index 000000000..5ffcc857d --- /dev/null +++ b/tests/annotations/conftest.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared fixtures for annotation-pipeline tests. + +Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests +can exercise real parquet reads and writes without needing a checked-in +LFS dataset. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + + +def _make_episode_table( + episode_index: int, + num_frames: int, + *, + fps: int = 10, + task_index: int = 0, +) -> pa.Table: + timestamps = [round(i / fps, 6) for i in range(num_frames)] + frame_indices = list(range(num_frames)) + return pa.Table.from_pydict( + { + "episode_index": [episode_index] * num_frames, + "frame_index": frame_indices, + "timestamp": timestamps, + "task_index": [task_index] * num_frames, + "subtask_index": [0] * num_frames, # legacy column the writer must drop + } + ) + + +def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path: + """Create a fixture dataset under ``root``. + + ``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``. + Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet`` + so the writer's per-shard rewrite path is exercised. 
+ """ + data_dir = root / "data" / "chunk-000" + data_dir.mkdir(parents=True, exist_ok=True) + tasks = {} + for episode_index, num_frames, task_text in episode_specs: + task_index = len(tasks) + if task_text not in tasks.values(): + tasks[task_index] = task_text + else: + task_index = next(k for k, v in tasks.items() if v == task_text) + table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index) + path = data_dir / f"file-{episode_index:03d}.parquet" + pq.write_table(table, path) + + meta_dir = root / "meta" + meta_dir.mkdir(parents=True, exist_ok=True) + tasks_table = pa.Table.from_pydict( + { + "task_index": list(tasks.keys()), + "task": list(tasks.values()), + } + ) + pq.write_table(tasks_table, meta_dir / "tasks.parquet") + + info = { + "codebase_version": "v3.1", + "fps": fps, + "total_episodes": len(episode_specs), + } + (meta_dir / "info.json").write_text(json.dumps(info, indent=2)) + + return root + + +@pytest.fixture +def fixture_dataset_root(tmp_path: Path) -> Path: + """A tiny dataset with two episodes, 12 frames each at 10 fps.""" + return _build_dataset( + tmp_path / "ds", + episode_specs=[ + (0, 12, "Could you tidy the kitchen please?"), + (1, 12, "Please clean up the kitchen"), + ], + fps=10, + ) + + +@pytest.fixture +def single_episode_root(tmp_path: Path) -> Path: + return _build_dataset( + tmp_path / "ds_one", + episode_specs=[(0, 30, "Pour water from the bottle into the cup.")], + fps=10, + ) diff --git a/tests/annotations/run_e2e_smoke.py b/tests/annotations/run_e2e_smoke.py new file mode 100644 index 000000000..6d35266f7 --- /dev/null +++ b/tests/annotations/run_e2e_smoke.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Opt-in E2E smoke run for ``make annotation-e2e``. + +Builds the same fixture used by the pytest suite, runs the full +annotation pipeline against it with a stub VLM, and prints a short report. +This is intentionally not a pytest test — it exercises the CLI plumbing +without depending on conftest.py fixtures. 
+""" + +from __future__ import annotations + +import json +import sys +import tempfile +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq + +from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig +from lerobot.annotations.steerable_pipeline.executor import Executor +from lerobot.annotations.steerable_pipeline.modules import ( + GeneralVqaModule, + InterjectionsAndSpeechModule, + PlanSubtasksMemoryModule, +) +from lerobot.annotations.steerable_pipeline.validator import StagingValidator +from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient +from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter + + +def _build_dataset(root: Path) -> Path: + data_dir = root / "data" / "chunk-000" + data_dir.mkdir(parents=True, exist_ok=True) + n = 30 + timestamps = [round(i / 10, 6) for i in range(n)] + table = pa.Table.from_pydict( + { + "episode_index": [0] * n, + "frame_index": list(range(n)), + "timestamp": timestamps, + "task_index": [0] * n, + "subtask_index": [0] * n, + } + ) + pq.write_table(table, data_dir / "file-000.parquet") + meta = root / "meta" + meta.mkdir(parents=True, exist_ok=True) + pq.write_table( + pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]}), + meta / "tasks.parquet", + ) + (meta / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10})) + return root + + +def _stub_responder(messages): + text = "" + for m in messages: + if m.get("role") == "user": + content = m.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + elif isinstance(content, str): + text = content + if "Decompose the demonstration" in text: + return { + "subtasks": [ + {"text": "grasp the bottle", "start": 0.0, "end": 1.0}, + {"text": "pour into the cup", "start": 1.0, "end": 2.0}, + {"text": "place the bottle down", "start": 2.0, "end": 3.0}, + ] + } + if "concise hierarchical PLAN" in text: + return {"plan": "1. grasp\n2. pour\n3. 
place"} + if "Update the memory" in text: + return {"memory": "poured once"} + if "acknowledgement the robot" in text: + return {"text": "Sure."} + if "ONE realistic interruption" in text: + return {"interjection": "use less water", "speech": "Using less water."} + if "frame-grounded visual question" in text: + return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}} + return None + + +def main() -> int: + with tempfile.TemporaryDirectory() as tmp: + root = _build_dataset(Path(tmp) / "ds") + vlm = StubVlmClient(responder=_stub_responder) + cfg = AnnotationPipelineConfig() + executor = Executor( + config=cfg, + module_1=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1), + module_2=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed), + module_3=GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed), + writer=LanguageColumnsWriter(), + validator=StagingValidator(), + ) + summary = executor.run(root) + print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}") + print(f"validation: {summary.validation_report.summary()}") + print(f"shards rewritten: {len(summary.written_paths)}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/annotations/test_modules.py b/tests/annotations/test_modules.py new file mode 100644 index 000000000..27ab92e3d --- /dev/null +++ b/tests/annotations/test_modules.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module 1/2/3 unit tests with stubbed VLMs.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from lerobot.annotations.steerable_pipeline.config import ( + Module1Config, + Module2Config, + Module3Config, +) +from lerobot.annotations.steerable_pipeline.modules import ( + GeneralVqaModule, + InterjectionsAndSpeechModule, + PlanSubtasksMemoryModule, +) +from lerobot.annotations.steerable_pipeline.reader import iter_episodes +from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging + +from ._helpers import make_canned_responder + + +def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None: + vlm = make_canned_responder( + { + "Decompose the demonstration": { + "subtasks": [ + {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4}, + {"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8}, + {"text": "place the sponge into the sink", "start": 0.8, "end": 1.1}, + ] + }, + "write a concise hierarchical PLAN": {"plan": "1. grasp\n2. wipe\n3. 
place"}, + "Update the memory": {"memory": "wiped the counter once"}, + }, + ) + module = PlanSubtasksMemoryModule(vlm=vlm, config=Module1Config()) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("module_1") + + styles = {r["style"] for r in rows} + assert {"subtask", "plan", "memory"}.issubset(styles) + # subtask timestamps must be exact frame timestamps + frame_set = set(record.frame_timestamps) + for row in rows: + assert row["timestamp"] in frame_set + # exactly one plan row at t0 + plan_rows = [r for r in rows if r["style"] == "plan"] + assert len(plan_rows) == 1 + assert plan_rows[0]["timestamp"] == record.frame_timestamps[0] + + +def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None: + vlm = make_canned_responder( + {"acknowledgement the robot": {"text": "Sure, on it."}}, + ) + module = InterjectionsAndSpeechModule( + vlm=vlm, + config=Module2Config(max_interjections_per_episode=0), + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("module_2") + assert len(rows) == 1 + only = rows[0] + assert only["role"] == "assistant" + assert only["style"] is None + assert only["content"] is None + assert only["timestamp"] == record.frame_timestamps[0] + assert only["tool_calls"][0]["function"]["name"] == "say" + + +def test_module2_mid_episode_emits_paired_interjection_and_speech( + fixture_dataset_root: Path, tmp_path: Path +) -> None: + vlm = make_canned_responder( + { + "acknowledgement the robot": {"text": "OK."}, + "ONE realistic interruption": { + "interjection": "actually skip the dishes", + "speech": "Skipping the dishes.", + }, + }, + ) + module = InterjectionsAndSpeechModule( + vlm=vlm, + config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2), + seed=7, + ) + record = next(iter_episodes(fixture_dataset_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("module_2") + + interjections = [r for r in rows if r["style"] == "interjection"] + speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"] + assert len(interjections) == 1 + assert len(speeches) >= 2 # initial t=0 + one paired with the interjection + inter_t = interjections[0]["timestamp"] + assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches) + + +def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None: + payload = { + "question": "How many cups?", + "answer": {"label": "cup", "count": 2, "note": "white & blue"}, + } + vlm = make_canned_responder({"frame-grounded visual question": payload}) + module = GeneralVqaModule( + vlm=vlm, + config=Module3Config(vqa_emission_hz=1.0, K=3), + seed=1, + ) + record = next(iter_episodes(single_episode_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("module_3") + user_ts = [r["timestamp"] for r in rows if r["role"] == "user" and r["style"] == "vqa"] + assistant_ts = [r["timestamp"] for r in rows if r["role"] == "assistant" and r["style"] == "vqa"] + # at most one user (vqa) per frame; same for assistant + assert len(user_ts) == len(set(user_ts)) + assert len(assistant_ts) == len(set(assistant_ts)) + # every emitted 
timestamp must be an exact source frame timestamp + frame_set = set(record.frame_timestamps) + for ts in user_ts + assistant_ts: + assert ts in frame_set + + +def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None: + payload = { + "question": "Where is the cup?", + "answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}, + } + vlm = make_canned_responder({"frame-grounded visual question": payload}) + module = GeneralVqaModule( + vlm=vlm, + config=Module3Config(vqa_emission_hz=1.0, K=2), + seed=2, + ) + record = next(iter_episodes(single_episode_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + rows = staging.read("module_3") + for row in rows: + if row["role"] == "assistant" and row["style"] == "vqa": + decoded = json.loads(row["content"]) + assert "detections" in decoded diff --git a/tests/annotations/test_pipeline_recipe_render.py b/tests/annotations/test_pipeline_recipe_render.py new file mode 100644 index 000000000..d881f9961 --- /dev/null +++ b/tests/annotations/test_pipeline_recipe_render.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pyarrow.parquet as pq + +from lerobot.annotations.steerable_pipeline.config import ( + AnnotationPipelineConfig, + Module1Config, + Module2Config, + Module3Config, +) +from lerobot.annotations.steerable_pipeline.executor import Executor +from lerobot.annotations.steerable_pipeline.modules import ( + GeneralVqaModule, + InterjectionsAndSpeechModule, + PlanSubtasksMemoryModule, +) +from lerobot.annotations.steerable_pipeline.validator import StagingValidator +from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter +from lerobot.configs.recipe import TrainingRecipe +from lerobot.datasets.language_render import render_sample + +from ._helpers import make_canned_responder + +_RECIPE_PATH = ( + Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes" / "pi05_hirobot.yaml" +) + + +def _build_executor() -> Executor: + vlm = make_canned_responder( + { + "Decompose the demonstration": { + "subtasks": [ + {"text": "grasp the bottle", "start": 0.0, "end": 0.5}, + {"text": "pour into the cup", "start": 0.5, "end": 1.0}, + {"text": "place the bottle down", "start": 1.0, "end": 1.5}, + ] + }, + "write a concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. 
place"}, + "Update the memory": {"memory": "poured once"}, + "acknowledgement the robot": {"text": "Sure."}, + "ONE realistic interruption": { + "interjection": "use less water", + "speech": "Using less water.", + }, + "frame-grounded visual question": { + "question": "How many cups?", + "answer": {"label": "cup", "count": 1}, + }, + }, + ) + config = AnnotationPipelineConfig( + module_1=Module1Config(), + module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5), + module_3=Module3Config(vqa_emission_hz=1.0, K=2), + ) + return Executor( + config=config, + module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1), + module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed), + module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed), + writer=LanguageColumnsWriter(), + validator=StagingValidator(), + ) + + +def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output( + single_episode_root: Path, +) -> None: + executor = _build_executor() + summary = executor.run(single_episode_root) + # validator may emit warnings but no errors for the synthetic fixture + assert summary.validation_report.ok, summary.validation_report.summary() + + table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet") + persistent_lists = table.column("language_persistent").to_pylist() + events_lists = table.column("language_events").to_pylist() + timestamps = table.column("timestamp").to_pylist() + + recipe = TrainingRecipe.from_yaml(_RECIPE_PATH) if hasattr(TrainingRecipe, "from_yaml") else None + if recipe is None: + # PR 1 may not expose from_yaml; load via PyYAML and TrainingRecipe(**...) + import yaml + + loaded = yaml.safe_load(_RECIPE_PATH.read_text(encoding="utf-8")) + recipe = TrainingRecipe(**loaded) + + rendered_any = False + for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True): + result = render_sample( + recipe=recipe, + persistent=persistent, + events=events, + t=float(ts), + sample_idx=0, + dataset_ctx={"task": "Pour water from the bottle into the cup."}, + ) + if result is None: + continue + if result["messages"]: + rendered_any = True + assert result["target_message_indices"] + break + assert rendered_any, "PR 1 recipe rendered no messages from pipeline output" + + # Sanity: speech atom appears in events column intact + flat_events = [r for ev in events_lists for r in ev] + speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"] + assert speech_rows + say = speech_rows[0]["tool_calls"][0] + assert say["function"]["name"] == "say" + assert isinstance(say["function"]["arguments"]["text"], str) + # Tools column carries the say schema + tools = json.loads(table.column("tools").to_pylist()[0]) + assert tools and tools[0]["function"]["name"] == "say" diff --git a/tests/annotations/test_validator.py b/tests/annotations/test_validator.py new file mode 100644 index 000000000..906ef212b --- /dev/null +++ b/tests/annotations/test_validator.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Validator behavior tests.""" + +from __future__ import annotations + +import json +from pathlib import Path + +from lerobot.annotations.steerable_pipeline.reader import iter_episodes +from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging +from lerobot.annotations.steerable_pipeline.validator import StagingValidator +from lerobot.annotations.steerable_pipeline.writer import speech_atom + + +def _validate(root: Path, staging_dir: Path): + records = list(iter_episodes(root)) + return StagingValidator().validate(records, staging_dir) + + +def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + EpisodeStaging(staging_dir, 0).write( + "module_3", + [ + { + "role": "assistant", + "content": json.dumps({"label": "cup", "count": 2}, sort_keys=True), + "style": "vqa", + "timestamp": 9.999, # not on any 10 fps frame + "tool_calls": None, + } + ], + ) + report = _validate(fixture_dataset_root, staging_dir) + assert not report.ok + assert any("does not match any source frame timestamp" in e for e in report.errors) + + +def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + EpisodeStaging(staging_dir, 0).write( + "module_2", + [ + speech_atom(0.0, "Got it."), + # interjection at 0.3s with NO paired speech + { + "role": "user", + "content": "skip it", + "style": "interjection", + "timestamp": 0.3, + "tool_calls": None, + }, + ], + ) + report = _validate(fixture_dataset_root, staging_dir) + assert not report.ok + assert any("paired speech" in e for e in report.errors) + + +def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + EpisodeStaging(staging_dir, 0).write( + "module_1", + [ + { + "role": "assistant", + "content": "1. 
do x", + "style": "plan", + "timestamp": 0.0, + "tool_calls": None, + }, + { + "role": "assistant", + "content": "do x", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + }, + ], + ) + EpisodeStaging(staging_dir, 0).write( + "module_2", + [ + speech_atom(0.0, "Got it."), + speech_atom(0.4, "Replanning."), + { + "role": "user", + "content": "replan", + "style": "interjection", + "timestamp": 0.4, + "tool_calls": None, + }, + ], + ) + report = _validate(fixture_dataset_root, staging_dir) + # missing co-timestamped plan refresh at 0.4s → error + assert not report.ok + assert any("co-timestamped plan update" in e for e in report.errors) + + +def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + EpisodeStaging(staging_dir, 0).write( + "module_1", + [ + {"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None}, + ], + ) + report = _validate(fixture_dataset_root, staging_dir) + assert not report.ok + assert any("module_1 emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors) diff --git a/tests/annotations/test_writer.py b/tests/annotations/test_writer.py new file mode 100644 index 000000000..14e835bf3 --- /dev/null +++ b/tests/annotations/test_writer.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Writer correctness tests.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pyarrow.parquet as pq +import pytest + +from lerobot.annotations.steerable_pipeline.reader import iter_episodes +from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging +from lerobot.annotations.steerable_pipeline.writer import ( + LanguageColumnsWriter, + speech_atom, +) + + +def _stage_episode( + staging_dir: Path, + episode_index: int, + *, + module_1: list[dict] | None = None, + module_2: list[dict] | None = None, + module_3: list[dict] | None = None, +) -> None: + staging = EpisodeStaging(staging_dir, episode_index) + if module_1 is not None: + staging.write("module_1", module_1) + if module_2 is not None: + staging.write("module_2", module_2) + if module_3 is not None: + staging.write("module_3", module_3) + + +def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None: + """Every frame in an episode has a byte-identical persistent list.""" + staging_dir = tmp_path / "stage" + _stage_episode( + staging_dir, + 0, + module_1=[ + { + "role": "assistant", + "content": "grasp the sponge", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + }, + { + "role": "assistant", + "content": "1. wipe\n2. 
dry", + "style": "plan", + "timestamp": 0.0, + "tool_calls": None, + }, + { + "role": "assistant", + "content": "wiped the counter", + "style": "memory", + "timestamp": 0.5, + "tool_calls": None, + }, + ], + ) + records = list(iter_episodes(fixture_dataset_root)) + LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) + + table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") + persistent = table.column("language_persistent").to_pylist() + first = persistent[0] + assert first # non-empty + for row in persistent: + assert row == first, "persistent slice must be byte-identical across all frames" + + +def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + _stage_episode( + staging_dir, + 0, + module_2=[ + speech_atom(0.0, "Got it."), + { + "role": "user", + "content": "skip the dishes", + "style": "interjection", + "timestamp": 0.5, + "tool_calls": None, + }, + speech_atom(0.5, "Skipping the dishes."), + ], + ) + records = list(iter_episodes(fixture_dataset_root)) + LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) + + table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") + timestamps = table.column("timestamp").to_pylist() + events = table.column("language_events").to_pylist() + for ts, ev in zip(timestamps, events, strict=True): + if abs(ts - 0.0) < 1e-9: + assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev + elif abs(ts - 0.5) < 1e-9: + assert any(r.get("style") == "interjection" for r in ev), ev + assert any(r.get("style") is None for r in ev), ev + else: + assert ev == [] + + +def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + _stage_episode( + staging_dir, + 0, + module_1=[ + { + "role": "assistant", + "content": "do X", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + }, + { + "role": "assistant", + "content": "1. 
do X", + "style": "plan", + "timestamp": 0.0, + "tool_calls": None, + }, + { + "role": "assistant", + "content": "did X", + "style": "memory", + "timestamp": 0.3, + "tool_calls": None, + }, + ], + module_2=[ + speech_atom(0.0, "OK"), + { + "role": "user", + "content": "wait", + "style": "interjection", + "timestamp": 0.2, + "tool_calls": None, + }, + speech_atom(0.2, "Waiting"), + ], + module_3=[ + { + "role": "user", + "content": "where is the cup?", + "style": "vqa", + "timestamp": 0.4, + "tool_calls": None, + }, + { + "role": "assistant", + "content": json.dumps( + {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]}, + sort_keys=True, + ), + "style": "vqa", + "timestamp": 0.4, + "tool_calls": None, + }, + ], + ) + records = list(iter_episodes(fixture_dataset_root)) + LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) + table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") + + persistent = table.column("language_persistent").to_pylist()[0] + persistent_styles = {r["style"] for r in persistent} + assert persistent_styles == {"subtask", "plan", "memory"} + + all_events = [r for ev in table.column("language_events").to_pylist() for r in ev] + event_styles = {r.get("style") for r in all_events} + assert event_styles == {None, "interjection", "vqa"} + + +def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + _stage_episode( + staging_dir, + 0, + module_1=[ + { + "role": "assistant", + "content": "do X", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + }, + ], + ) + records = list(iter_episodes(fixture_dataset_root)) + writer = LanguageColumnsWriter() + writer.write_all(records, staging_dir, fixture_dataset_root) + + path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet" + table_a = pq.read_table(path) + assert "subtask_index" not in table_a.column_names + assert "language_persistent" in table_a.column_names + assert "language_events" in table_a.column_names + assert "tools" in table_a.column_names + + # second pass — must produce identical bytes for the language columns + records_again = list(iter_episodes(fixture_dataset_root)) + writer.write_all(records_again, staging_dir, fixture_dataset_root) + table_b = pq.read_table(path) + assert ( + table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist() + ) + assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist() + + +def test_writer_normalize_rejects_misrouted_persistent_style() -> None: + """``_normalize_persistent_row`` must reject any non-persistent style.""" + from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row + + with pytest.raises(ValueError, match="non-persistent style"): + _normalize_persistent_row( + {"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None} + ) + + +def test_writer_normalize_rejects_misrouted_event_style() -> None: + """``_normalize_event_row`` must reject any persistent style.""" + from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row + + with pytest.raises(ValueError): + _normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None}) + + +def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None: + staging_dir = tmp_path / "stage" + _stage_episode( + 
+
+
+def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None:
+    staging_dir = tmp_path / "stage"
+    _stage_episode(
+        staging_dir,
+        0,
+        module_1=[
+            {"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
+        ],
+    )
+    records = list(iter_episodes(fixture_dataset_root))
+    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
+    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
+    tools = table.column("tools").to_pylist()
+    assert tools, "tools column missing"
+    decoded = json.loads(tools[0])
+    assert isinstance(decoded, list)
+    assert len(decoded) == 1
+    assert decoded[0]["function"]["name"] == "say"
+    params = decoded[0]["function"]["parameters"]
+    assert params["properties"]["text"]["type"] == "string"
+
+
+def test_speech_atom_shape_matches_plan_spec() -> None:
+    atom = speech_atom(2.5, "I'm cleaning up!")
+    assert atom["role"] == "assistant"
+    assert atom["style"] is None
+    assert atom["content"] is None
+    assert atom["timestamp"] == 2.5
+    assert isinstance(atom["tool_calls"], list)
+    call = atom["tool_calls"][0]
+    assert call["type"] == "function"
+    assert call["function"]["name"] == "say"
+    assert call["function"]["arguments"]["text"] == "I'm cleaning up!"
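For reference, the atom shape this last test pins down can be reproduced in a few lines. This is a minimal stand-in consistent with the assertions above, not the repo's code; the real `speech_atom` lives in `lerobot.annotations.steerable_pipeline.writer`.

```python
# Minimal sketch matching the asserted shape: a tool-call-only assistant row
# with style=None and content=None. Illustrative stand-in only.
def speech_atom_sketch(timestamp: float, text: str) -> dict:
    return {
        "role": "assistant",
        "style": None,
        "content": None,
        "timestamp": timestamp,
        "tool_calls": [
            {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
        ],
    }
```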