feat: language annotation pipeline (PR 2/3)

Adds the steerable annotation pipeline (`lerobot-annotate`) that populates
the `language_persistent` and `language_events` columns introduced in
PR 1 directly into `data/chunk-*/file-*.parquet`. No flavor namespace,
no sidecar tree.

Modules produced:
- Module 1 (plan_subtasks_memory): Pi0.7-style subtasks, plan (init +
  refresh on interjection), MEM-style memory at subtask boundaries.
- Module 2 (interjections_and_speech): t=0 speech-only acknowledgement,
  mid-episode paired interjection + speech tool-call atom.
- Module 3 (general_vqa): bbox/keypoint/count/attribute/spatial pairs at
  configurable cadence with one-retry JSON validation.

Writer enforces: per-episode persistent identity, exact-frame event
timestamps, column routing per `column_for_style`, dataset-level `tools`
column with the `say` schema, drops legacy `subtask_index`. Validator
runs against staged JSONL artifacts before the writer rewrites parquet.

Adds `lerobot-annotate` console script, `annotations` extra (datatrove +
optional vllm), `make annotation-e2e` opt-in smoke target, and
`docs/source/annotation_pipeline.mdx`.

Branched from PR 1 (`feat/language-columns`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-27 16:22:51 +02:00
parent 1ca38d9748
commit 785cee429e
33 changed files with 3409 additions and 0 deletions
+6
View File
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
--env.episode_length=5 \
--eval.n_episodes=1 \
--eval.batch_size=1
# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
uv run python -m tests.annotations.run_e2e_smoke
+2
View File
@@ -33,6 +33,8 @@
title: Using the Dataset Tools
- local: language_and_recipes
title: Language Columns and Recipes
- local: annotation_pipeline
title: Annotation Pipeline
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
+133
View File
@@ -0,0 +1,133 @@
# Annotation Pipeline
`lerobot-annotate` populates the two language columns introduced by the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — directly into
`data/chunk-*/file-*.parquet`. There is no flavor namespace and no sidecar
file tree: multiple revisions of a dataset mean multiple dataset copies.
## What the pipeline produces
Three modules write into a per-episode staging tree, then a single writer
rewrites the data shards in place:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | -------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 |
| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 |
| `memory` (MEM-style compression) | `language_persistent` | Module 1 |
| `interjection` | `language_events` | Module 2 |
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
| `vqa` (user / assistant pair) | `language_events` | Module 3 |
The writer also adds a dataset-level `tools` column carrying the JSON schema
for the `say` tool call, and drops the legacy `subtask_index` column.
## How to run it locally or on SLURM
Install the extra and invoke the console script:
```bash
uv sync --extra annotations
uv run lerobot-annotate \
--root=/path/to/dataset \
--vlm.backend=transformers \
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
```
The executor picks `LocalPipelineExecutor` for small datasets and
`SlurmPipelineExecutor` for large ones based on
`--executor.auto_threshold` (default 32 episodes). Force local with
`--executor.force_local=true`. SLURM jobs honour `--executor.slurm_partition`,
`--executor.slurm_gpus`, and `--executor.slurm_time`.
## Style-to-recipe consumer mapping
The pipeline produces exactly the styles consumed by
`src/lerobot/configs/recipes/pi05_hirobot.yaml`:
- `low_level_execution`, `high_level_subtask`, `memory_update` consume
`subtask`/`plan`/`memory` from `language_persistent`.
- `user_interjection_response` consumes `interjection` events plus the
paired speech atom (merged into one assistant target turn via
`tool_calls_from`) and the same-timestamp `plan` refresh.
- `ask_vqa` consumes the `(vqa, user)` and `(vqa, assistant)` pairs from
`language_events`.
## Why the design is scoped to the canonical recipe
Two things drive the scope:
1. **Persistent state vs exact-event split.** Persistent rows (`subtask`,
`plan`, `memory`) broadcast per episode and answer "what state is in
force at this frame?". Event rows (`interjection`, `vqa`, speech) only
appear on the exact frame whose timestamp matches the emission. The
pipeline writes timestamps taken straight from the source parquet — no
floating-point recomputation.
2. **One Qwen-VL pass.** All three modules share a single VLM client
(vLLM if available, transformers fallback) so the cost is one model
load per dataset, not three.
## Module independence and staged reruns
Each module writes its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
prompt iteration cheap — re-running one module overwrites only its own
JSONL file before the writer composes the final parquet. Modules can be
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
test them in isolation.
## Validation/report checks before final write
Before the writer runs, `StagingValidator` checks:
- exact frame-timestamp alignment for every event row;
- no orphan speech / interjection pairs;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (warning, not error);
- VQA assistant `content` parses as JSON in one of the
bbox / keypoint / count / attribute / spatial shapes;
- every row routes to the column dictated by `column_for_style(style)`.
Errors abort the writer (`--skip_validation=true` overrides for debugging).
## Paper inspirations per module
- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
atom granularity ("pick up one piece of lettuce", "place bowl to box");
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
what" detail.
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
compression directive: keep only minimal relevant information; functional
outcomes preserved, specific attributes dropped.
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
arguments:{text:...}}}]`).
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable Policies' multi-abstraction grounding.
Future maintainers should adjust the prompt templates in
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
references rather than rewriting from scratch.
## Compute and list-size estimates
Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
O(`max_interjections_per_episode`) Module 2 calls, and
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
KB at most (parquet dictionary-encodes one entry per episode);
`language_events` is empty on most frames and is bounded by the number of
emissions, not `num_frames × num_emissions`.
## Reproducibility via seed and prompt hashes
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
timestamps and VQA question types. Combined with the deterministic prompt
templates checked into `prompts/`, two runs at the same seed against the
same dataset and the same model checkpoint produce byte-identical staging
artifacts. Prompt edits are recorded by file hash; future tooling can pin
expected `(seed, prompt_hash)` pairs into the dataset card.
+10
View File
@@ -200,6 +200,15 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Annotation pipeline (lerobot-annotate). datatrove is mandatory; vllm is
# the preferred backend on Linux, with a transformers fallback elsewhere.
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
"datatrove>=0.4.0,<2.0.0",
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -289,6 +298,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
+15
View File
@@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles
- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech
- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter
__all__ = [
"AnnotationPipelineConfig",
"LanguageColumnsWriter",
"StagingValidator",
"ValidationReport",
]
@@ -0,0 +1,108 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal
@dataclass
class Module1Config:
"""Module 1 hyperparameters: plan + subtasks + memory."""
enabled: bool = True
keyframes_per_episode: int = 8
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
@dataclass
class Module2Config:
"""Module 2 hyperparameters: interjections + paired speech."""
enabled: bool = True
max_interjections_per_episode: int = 1
interjection_min_t: float = 2.0
@dataclass
class Module3Config:
"""Module 3 hyperparameters: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
K: int = 3
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
@dataclass
class VlmConfig:
"""Shared Qwen-VL client configuration."""
backend: Literal["vllm", "transformers", "stub"] = "transformers"
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
max_new_tokens: int = 512
temperature: float = 0.2
json_mode: bool = True
batch_size: int = 4
@dataclass
class ExecutorConfig:
"""Executor selection and SLURM hyperparameters."""
auto_threshold: int = 32
force_local: bool = False
slurm_partition: str | None = None
slurm_gpus: int = 1
slurm_time: str = "06:00:00"
workers: int = 1
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate``.
Mirrors the structure of :class:`lerobot.configs.train.TrainPipelineConfig`:
a draccus-parsed dataclass that contains nested per-module sub-configs and
leaves the dataset, executor, and VLM choices independently knobbable.
Output is always in-place: the writer rewrites ``data/chunk-*/file-*.parquet``
in place. Multiple revisions of the same dataset live in separate copies.
"""
repo_id: str | None = None
root: Path | None = None
staging_dir: Path | None = None
"""If unset, defaults to ``<root>/.annotate_staging/``."""
seed: int = 1729
module_1: Module1Config = field(default_factory=Module1Config)
module_2: Module2Config = field(default_factory=Module2Config)
module_3: Module3Config = field(default_factory=Module3Config)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -0,0 +1,163 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Executor selection: local vs SLURM via datatrove.
The executor plans **four phases** with the dependency order from the plan:
phase 1: Module 1 (plan + subtasks + memory)
phase 2: Module 2 (interjections + speech)
phase 3: Module 1 plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: Module 3 (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why ``executor.py`` documents the dependency: Module 1 must be
re-entered after Module 2 to refresh ``plan`` rows at interjection times.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .config import AnnotationPipelineConfig, ExecutorConfig
from .reader import EpisodeRecord, iter_episodes
from .staging import EpisodeStaging
from .validator import StagingValidator
from .writer import LanguageColumnsWriter
logger = logging.getLogger(__name__)
@dataclass
class PhaseResult:
"""Summary of one pipeline phase across all episodes."""
name: str
episodes_processed: int
episodes_skipped: int
@dataclass
class PipelineRunSummary:
"""Aggregated result returned by :meth:`Executor.run`."""
phases: list[PhaseResult]
written_paths: list[Path]
validation_report: Any # ValidationReport, kept Any to avoid import cycle
def select_executor_class(num_episodes: int, config: ExecutorConfig) -> str:
"""Return ``"local"`` or ``"slurm"`` based on the threshold.
The plan's "executor selection threshold" lives in
:class:`ExecutorConfig.auto_threshold`. ``force_local`` always wins.
"""
if config.force_local:
return "local"
return "local" if num_episodes <= config.auto_threshold else "slurm"
@dataclass
class Executor:
"""Run all four phases over a dataset root.
The executor is intentionally framework-agnostic: by default it runs the
phases inline (suitable for tests, small datasets, and the CLI's
``--force-local`` mode). It will optionally hand off to datatrove's
:class:`LocalPipelineExecutor` or :class:`SlurmPipelineExecutor` when those
are installed and the dataset is large enough to benefit from them.
Tests construct the executor directly with stub modules.
"""
config: AnnotationPipelineConfig
module_1: Any # PlanSubtasksMemoryModule
module_2: Any # InterjectionsAndSpeechModule
module_3: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
def run(self, root: Path) -> PipelineRunSummary:
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
n = len(records)
if n == 0:
raise ValueError(f"No episodes found under {root}/data/")
executor_kind = select_executor_class(n, self.config.executor)
logger.info("annotate: %d episodes; executor=%s", n, executor_kind)
staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True)
phases: list[PhaseResult] = []
# Phase 1: Module 1 (plan + subtasks + memory)
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
# Phase 2: Module 2 (interjections + speech)
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
# Phase 3: Module 1 plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: Module 3 (VQA)
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}")
written = self.writer.write_all(records, staging_dir, root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
def _run_module_phase(
self,
name: str,
records: list[EpisodeRecord],
staging_dir: Path,
module: Any,
) -> PhaseResult:
if not module.enabled:
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
processed += 1
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult:
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
therefore calls back into Module 1 with the interjection timestamps so
Module 1's existing prompt path is reused.
"""
if not self.module_1.enabled or not self.module_2.enabled:
return PhaseResult(
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
)
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_times = [
row["timestamp"] for row in staging.read("module_2") if row.get("style") == "interjection"
]
if interjection_times:
self.module_1.run_plan_updates(record, staging, interjection_times)
processed += 1
return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
]
@@ -0,0 +1,146 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 3: general VQA at a timed cadence.
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
emission so each frame gets at most one ``(vqa, user)`` and one
``(vqa, assistant)`` pair — keeps the resolver contract scalar.
Question types covered (per the plan's Module 3 table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
"""
from __future__ import annotations
import json
import random
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from ..config import Module3Config
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..validator import classify_vqa_answer
from ..vlm_client import VlmClient
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
"""Return the relative frame indices to anchor VQA emissions to.
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
consecutive frames starting at the tick. Ticks fall on the nearest
available source frame timestamp.
"""
if hz <= 0 or k <= 0 or not frame_timestamps:
return []
t0 = frame_timestamps[0]
t_last = frame_timestamps[-1]
period = 1.0 / hz
indices: list[int] = []
t = t0
while t <= t_last + 1e-9:
# find the index of the nearest frame to t
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
for offset in range(k):
j = nearest_i + offset
if j >= len(frame_timestamps):
break
if not indices or indices[-1] != j:
indices.append(j)
t += period
# dedupe while preserving order
seen: set[int] = set()
deduped: list[int] = []
for i in indices:
if i in seen:
continue
seen.add(i)
deduped.append(i)
return deduped
@dataclass
class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: Module3Config
seed: int = 1729
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("module_3", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
rows: list[dict[str, Any]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype)
if qa is None:
continue
question, answer = qa
rows.append(
{
"role": "user",
"content": question,
"style": "vqa",
"timestamp": ts,
"tool_calls": None,
}
)
rows.append(
{
"role": "assistant",
"content": json.dumps(answer, sort_keys=True),
"style": "vqa",
"timestamp": ts,
"tool_calls": None,
}
)
staging.write("module_3", rows)
def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None:
prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
return None
question = result.get("question")
answer = result.get("answer")
if not isinstance(question, str) or not question.strip():
return None
if not isinstance(answer, dict):
return None
# The validator will enforce shape; here we just sanity-check that the
# answer matches *some* known shape so we can drop garbage early.
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
@@ -0,0 +1,129 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
canonical task). No interjection row — the canonical task is already the
user utterance from ``meta/tasks.parquet``.
2. For mid-episode interruptions, emit a co-timestamped pair:
{role:user, style:interjection, content:<text>}
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from ..config import Module2Config
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
if not frame_timestamps:
return float(t)
return float(min(frame_timestamps, key=lambda f: abs(f - t)))
@dataclass
class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: Module2Config
seed: int = 1729
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
if record.frame_timestamps:
t0 = float(record.frame_timestamps[0])
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
rows.extend(self._mid_episode_interjections(record))
staging.write("module_2", rows)
def _initial_speech(self, record: EpisodeRecord) -> str | None:
prompt = load_prompt("module_2_initial_speech").format(
episode_task=record.episode_task,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("text"), str):
text = result["text"].strip()
if text:
return text
return None
def _mid_episode_interjections(self, record: EpisodeRecord) -> list[dict[str, Any]]:
if self.config.max_interjections_per_episode <= 0:
return []
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
candidate_ts = [t for t in record.frame_timestamps if t >= self.config.interjection_min_t]
if not candidate_ts:
return []
n = min(self.config.max_interjections_per_episode, len(candidate_ts) // 4)
if n <= 0:
return []
chosen = sorted(rng.sample(candidate_ts, n))
out: list[dict[str, Any]] = []
for t in chosen:
t_snap = _snap_to_frame(t, record.frame_timestamps)
current_subtask = record.episode_task
prompt = load_prompt("module_2_interjection").format(
episode_task=record.episode_task,
current_subtask=current_subtask,
timestamp=t_snap,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
interjection_text = result.get("interjection")
speech_text = result.get("speech")
if not isinstance(interjection_text, str) or not interjection_text.strip():
continue
if not isinstance(speech_text, str) or not speech_text.strip():
continue
out.append(
{
"role": "user",
"content": interjection_text.strip(),
"style": "interjection",
"timestamp": t_snap,
"tool_calls": None,
}
)
out.append(speech_atom(t_snap, speech_text.strip()))
return out
@@ -0,0 +1,226 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from ..config import Module1Config
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, keyframe_indices
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp."""
if not frame_timestamps:
return float(t)
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
return float(nearest)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
All output is persistent (lives in ``language_persistent``):
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
Module 2 completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: Module1Config
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
subtask_spans = self._generate_subtasks(record)
# subtask rows
for span in subtask_spans:
rows.append(
{
"role": "assistant",
"content": span["text"],
"style": "subtask",
"timestamp": _snap_to_frame(span["start"], record.frame_timestamps),
"tool_calls": None,
}
)
# plan row at t=0
plan_text = self._generate_plan(record, subtask_spans)
if plan_text is not None:
t0 = record.frame_timestamps[0] if record.frame_timestamps else 0.0
rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": float(t0),
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start
prior_memory = ""
for i, span in enumerate(subtask_spans[1:], start=1):
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining)
if mem_text:
ts = _snap_to_frame(span["start"], record.frame_timestamps)
rows.append(
{
"role": "assistant",
"content": mem_text,
"style": "memory",
"timestamp": ts,
"tool_calls": None,
}
)
prior_memory = mem_text
staging.write("module_1", rows)
def run_plan_updates(
self,
record: EpisodeRecord,
staging: EpisodeStaging,
interjection_times: Sequence[float],
) -> None:
"""Append additional ``plan`` rows at every interjection timestamp."""
existing = staging.read("module_1")
spans = self._reconstruct_subtasks_from_rows(existing)
new_rows = list(existing)
for raw_t in interjection_times:
t = _snap_to_frame(raw_t, record.frame_timestamps)
plan_text = self._generate_plan(record, spans, refresh_t=t)
if plan_text is not None:
new_rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": t,
"tool_calls": None,
}
)
staging.write("module_1", new_rows)
@staticmethod
def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
out = []
last_t: float | None = None
for row in sorted(
(r for r in rows if r.get("style") == "subtask"),
key=lambda r: float(r["timestamp"]),
):
t = float(row["timestamp"])
if last_t is not None:
out[-1]["end"] = t
out.append({"text": row.get("content") or "", "start": t, "end": t})
last_t = t
return out
def _generate_subtasks(self, record: EpisodeRecord) -> list[dict[str, Any]]:
if record.row_count == 0 or not record.frame_timestamps:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
prompt = load_prompt("module_1_subtasks").format(
episode_task=record.episode_task,
num_keyframes=len(keyframe_local),
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
spans = result.get("subtasks") if isinstance(result, dict) else None
if not spans:
return []
# clamp to [t0, t_last] and sort
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
start = float(span["start"])
end = float(span["end"])
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(t0, min(start, t_last))
end = max(t0, min(end, t_last))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
return cleaned
def _generate_plan(
self,
record: EpisodeRecord,
subtask_spans: Sequence[dict[str, Any]],
*,
refresh_t: float | None = None,
) -> str | None:
if not subtask_spans:
return None
subtasks_text = "\n".join(f"- {s['text']}" for s in subtask_spans)
prompt = load_prompt("module_1_plan").format(
episode_task=record.episode_task,
subtasks_text=subtasks_text,
plan_max_steps=self.config.plan_max_steps,
)
if refresh_t is not None:
prompt += f"\n\n(This is a plan refresh after a user interjection at t={refresh_t:.2f}s.)\n"
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("plan"), str):
return result["plan"].strip()
return None
def _generate_memory(
self,
record: EpisodeRecord,
prior_memory: str,
completed: str,
remaining: Sequence[str],
) -> str:
prompt = load_prompt("module_1_memory").format(
episode_task=record.episode_task,
prior_memory=prior_memory or "(none)",
completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("memory"), str):
return result["memory"].strip()
return ""
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.
One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""
from __future__ import annotations
from pathlib import Path
_DIR = Path(__file__).parent
def load(name: str) -> str:
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
path = _DIR / f"{name}.txt"
return path.read_text(encoding="utf-8")
@@ -0,0 +1,25 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.
Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."
Concrete example from MEM:
Before: "I put a light green bowl, a dark blue bowl and a bright yellow
bowl into the top right cabinet"
After: "I placed three bowls in the top right cabinet"
Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
Update the memory. Drop irrelevant detail. Compress completed steps.
Keep WHAT happened, drop HOW. Shorter is better.
Output strictly valid JSON:
{{ "memory": "<one or two short sentences>" }}
@@ -0,0 +1,18 @@
You are the high-level planner for a robot demonstrating: "{episode_task}".
Given the subtask decomposition below, write a concise hierarchical PLAN
the robot should follow. Format the plan as a numbered list, one line per
high-level step. The plan describes the full task; subtasks are the atomic
skills used to execute it.
Subtasks for context:
{subtasks_text}
Authoring rules:
- 3 to {plan_max_steps} steps.
- Each step describes one logical chunk of the task, not one motion.
- Steps must be in execution order.
- Plain prose, no JSON, no markdown headers.
Output strictly valid JSON:
{{ "plan": "1. ...\n2. ...\n3. ..." }}
@@ -0,0 +1,31 @@
You are labeling a teleoperated robot demonstration.
The user originally asked: "{episode_task}"
You will be shown {num_keyframes} keyframes spaced evenly across the
episode. Decompose the demonstration into a list of consecutive atomic
subtasks the robot performs.
Authoring rules — based on Hi Robot (Shi 2025) atom granularity and Pi0.7
(Physical Intelligence 2025) "how, not what" detail:
- Each subtask is one atomic skill the low-level policy can execute, e.g.
"pick up one piece of lettuce", "place the bowl into the box",
"move the right arm to the left".
- Capture HOW the subtask is performed, not only WHAT — e.g. prefer
"grasp the handle of the sponge with the left hand" to "pick up the
sponge".
- Subtasks are non-overlapping and must cover the full episode in order.
- Each subtask spans at least {min_subtask_seconds} seconds.
- Do not exceed {max_steps} subtasks total.
- Every subtask's [start_time, end_time] must lie within
[0.0, {episode_duration}] seconds.
Output strictly valid JSON of shape:
{{
"subtasks": [
{{"text": "<how-not-what>", "start": <float>, "end": <float>}},
...
]
}}
@@ -0,0 +1,10 @@
The user just asked the robot: "{episode_task}".
Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: confident, friendly, single short sentence.
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".
Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -0,0 +1,27 @@
You are simulating a user mid-episode interruption for a robot doing:
"{episode_task}".
Synthesize ONE realistic interruption the user might say at this moment in
the demonstration, plus the robot's verbal acknowledgement.
Context (Hi Robot, Shi 2025) — interjections fall into one of these
scenario types:
- negative task: "actually skip X"
- situated correction: "that's not trash"
- specific constraint: "use less salt"
- preference: "could you also do Y"
Interruption rules:
- Must be plausible given the current subtask context.
- Must change the plan in a non-trivial way (a new constraint, skipped
step, or correction).
- One sentence each.
Current subtask context: {current_subtask}
Time into episode: {timestamp:.2f}s
Output strictly valid JSON:
{{
"interjection": "<single sentence the user says>",
"speech": "<single sentence the robot speaks back>"
}}
@@ -0,0 +1,32 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
"value": "<observed value>"}}
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
"above|below|near>", "object": "<obj>"}}
Generate a question of type "{question_type}". Output strictly valid JSON:
{{
"question": "<short, frame-grounded question>",
"answer": <object whose shape matches the schema above>
}}
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:
- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
This shape lets each module operate per-episode without loading all parquet
rows into memory at once. It deliberately does not depend on datatrove —
datatrove integration wraps this generator inside a ``PipelineStep`` in
:mod:`.executor`.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@dataclass
class EpisodeRecord:
"""Per-episode record yielded by the reader."""
episode_index: int
episode_task: str
frame_timestamps: tuple[float, ...]
frame_indices: tuple[int, ...]
data_path: Path
row_offset: int # row offset within the parquet file where this episode starts
row_count: int # number of rows for this episode
def frames_df(self): # type: ignore[no-untyped-def]
"""Lazy-load the pandas slice for this episode."""
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
table = pq.read_table(self.data_path)
df: pd.DataFrame = table.to_pandas()
slice_ = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(drop=True)
return slice_
def _load_tasks_lookup(root: Path) -> dict[int, str]:
tasks_path = root / DEFAULT_TASKS_PATH
if not tasks_path.exists():
return {}
table = pq.read_table(tasks_path)
cols = {name: table.column(name).to_pylist() for name in table.column_names}
if "task_index" in cols and "task" in cols:
return dict(zip(cols["task_index"], cols["task"], strict=True))
raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
Episodes are yielded in ascending ``episode_index`` order. The reader does
not assume a specific chunk/file layout: it scans every ``*.parquet``
under ``data/`` and groups by ``episode_index``.
"""
tasks = _load_tasks_lookup(root)
data_dir = root / "data"
parquet_files = sorted(data_dir.rglob("*.parquet"))
only_set = set(only_episodes) if only_episodes is not None else None
for path in parquet_files:
yield from _iter_one_path(path, tasks, only_set)
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
table = pq.read_table(path)
names = table.column_names
if "episode_index" not in names:
return
episode_col = table.column("episode_index").to_pylist()
timestamp_col = (
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
)
frame_col = (
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
)
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
def _build(
ep: int,
start: int,
end: int,
task_idx: int | None,
ts_buf: list[float],
fi_buf: list[int],
) -> EpisodeRecord | None:
if only_set is not None and ep not in only_set:
return None
task = tasks.get(task_idx, "") if task_idx is not None else ""
return EpisodeRecord(
episode_index=ep,
episode_task=task,
frame_timestamps=tuple(ts_buf),
frame_indices=tuple(fi_buf),
data_path=path,
row_offset=start,
row_count=end - start,
)
cur_ep: int | None = None
start_offset = 0
ts_buf: list[float] = []
fi_buf: list[int] = []
cur_task_idx: int | None = None
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
continue
if ep != cur_ep:
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
else:
ts_buf.append(timestamp_col[i])
fi_buf.append(frame_col[i])
if cur_ep is not None:
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
def gather_data_paths(root: Path) -> list[Path]:
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
return sorted((root / "data").rglob("*.parquet"))
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
table = pq.read_table(path, columns=["episode_index"])
episode_col = table.column("episode_index").to_pylist()
out: dict[int, tuple[int, int]] = {}
cur_ep: int | None = None
start = 0
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start = i
continue
if ep != cur_ep:
out[cur_ep] = (start, i - start)
cur_ep = ep
start = i
if cur_ep is not None:
out[cur_ep] = (start, len(episode_col) - start)
return out
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
n = record.row_count
if k <= 0 or n == 0:
return []
if k >= n:
return list(range(n))
step = (n - 1) / (k - 1) if k > 1 else 0.0
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
for path in gather_data_paths(root):
offsets = episode_offsets_per_path(path)
if episode_index in offsets:
start, count = offsets[episode_index]
return path, start, count
return None
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
found = lookup_data_path(root, episode_index)
if found is None:
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
path, start, count = found
table = pq.read_table(path, columns=["timestamp"])
timestamps = table.column("timestamp").to_pylist()[start : start + count]
return path, [float(t) for t in timestamps]
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.
Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.
JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""
from __future__ import annotations
import json
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"module_1",
"module_2",
"module_3",
)
@dataclass
class EpisodeStaging:
"""Filesystem layout for a single episode's staged module outputs."""
root: Path
episode_index: int
@property
def episode_dir(self) -> Path:
return self.root / f"episode_{self.episode_index:06d}"
def path_for(self, module: ModuleName) -> Path:
if module not in _MODULES:
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
return self.episode_dir / f"{module}.jsonl"
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
path = self.path_for(module)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
f.write("\n")
return path
def read(self, module: ModuleName) -> list[dict[str, Any]]:
path = self.path_for(module)
if not path.exists():
return []
out: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
return {m: self.read(m) for m in _MODULES}
def has(self, module: ModuleName) -> bool:
return self.path_for(module).exists()
def iter_staged_episodes(root: Path) -> Iterator[int]:
"""Yield episode indices for which any staging artifact exists."""
if not root.exists():
return
for child in sorted(root.iterdir()):
if child.is_dir() and child.name.startswith("episode_"):
try:
yield int(child.name.removeprefix("episode_"))
except ValueError:
continue
@@ -0,0 +1,271 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after Modules 13 have all written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
Checks (per the plan's "Intermediate staging and validation" section):
- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import (
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
column_for_style,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class ValidationReport:
"""Outcome of one validation pass across all episodes."""
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
episodes_checked: int = 0
@property
def ok(self) -> bool:
return not self.errors
def add_error(self, message: str) -> None:
self.errors.append(message)
def add_warning(self, message: str) -> None:
self.warnings.append(message)
def summary(self) -> str:
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
def classify_vqa_answer(payload: Any) -> str | None:
"""Best-effort classification of a VQA answer payload to a question type."""
if not isinstance(payload, dict):
return None
keys = set(payload.keys())
for kind, required in VQA_ANSWER_SHAPES.items():
if required.issubset(keys):
return kind
return None
@dataclass
class StagingValidator:
"""Walks the staging tree and produces a :class:`ValidationReport`."""
timestamp_atol: float = 0.0 # exact-match by default
def validate(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
) -> ValidationReport:
report = ValidationReport()
for record in records:
self._validate_episode(record, staging_dir, report)
report.episodes_checked += 1
return report
def _validate_episode(
self,
record: EpisodeRecord,
staging_dir: Path,
report: ValidationReport,
) -> None:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged = staging.read_all()
all_rows: list[dict[str, Any]] = []
for module_name, rows in staged.items():
for row in rows:
row = {**row, "_module": module_name}
all_rows.append(row)
frame_ts = set(record.frame_timestamps)
events: list[dict[str, Any]] = []
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
for row in events:
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
self._check_speech_interjection_pairs(events, report, record.episode_index)
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
self._check_vqa_json(events, report, record.episode_index)
def _check_column_routing(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
) -> None:
style = row.get("style")
module = row.get("_module")
try:
target_col = column_for_style(style)
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
def _check_event_timestamp_alignment(
self,
row: dict[str, Any],
frame_ts: set[float],
report: ValidationReport,
episode_index: int,
) -> None:
ts = row.get("timestamp")
if ts is None:
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
return
if self.timestamp_atol == 0.0:
if float(ts) not in frame_ts:
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
)
else:
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
)
def _check_speech_interjection_pairs(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
speech_ts: dict[float, int] = {}
interjection_ts: dict[float, int] = {}
for row in events:
ts = row.get("timestamp")
if ts is None:
continue
ts_f = float(ts)
if row.get("style") is None and row.get("role") == "assistant":
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
if row.get("style") == "interjection":
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
for ts in interjection_ts:
if ts not in speech_ts:
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
def _check_plan_memory_consistency(
self,
persistent: Sequence[dict[str, Any]],
events: Sequence[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
interjection_ts = sorted(
{
float(r["timestamp"])
for r in events
if r.get("style") == "interjection" and r.get("timestamp") is not None
}
)
if persistent and not plan_ts:
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
# every interjection should have a same-timestamp plan refresh
for ts in interjection_ts:
if ts not in set(plan_ts):
report.add_error(
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
)
# memory should be emitted at subtask boundaries (subset relation)
if memory_ts and subtask_ts:
mem_set = set(memory_ts)
sub_set = set(subtask_ts)
stray = sorted(mem_set - sub_set)
if stray:
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
def _check_vqa_json(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
for row in events:
if row.get("style") != "vqa" or row.get("role") != "assistant":
continue
content = row.get("content")
if content is None:
report.add_error(
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
)
continue
try:
payload = json.loads(content)
except (TypeError, ValueError) as exc:
report.add_error(
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
)
continue
shape = classify_vqa_answer(payload)
if shape is None:
report.add_error(
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
)
@@ -0,0 +1,204 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.
The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.
The client speaks one method, :meth:`VlmClient.generate_json`, which:
- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output (``json_mode=True`` enables guided decoding when the
backend supports it),
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
message before raising.
"""
from __future__ import annotations
import json
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Protocol
from .config import VlmConfig
class VlmClient(Protocol):
"""Protocol every backend must implement."""
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
"""Generate one JSON-decoded response per messages list."""
@dataclass
class StubVlmClient:
"""Deterministic stub used in unit tests.
A test passes a callable that maps the *last user message text* (or, if
that is empty, the full message list) to a JSON-serializable response.
"""
responder: Callable[[Sequence[dict[str, Any]]], Any]
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
return [self.responder(list(messages)) for messages in messages_batch]
def _strip_to_json(text: str) -> Any:
text = text.strip()
if text.startswith("```"):
# tolerate ```json ... ``` fences from chat-tuned backbones
first = text.find("\n")
last = text.rfind("```")
if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip()
return json.loads(text)
@dataclass
class _GenericTextClient:
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
config: VlmConfig
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
temp = temperature if temperature is not None else self.config.temperature
raw = self.generate_text(messages_batch, max_tok, temp)
out: list[Any] = []
for messages, text in zip(messages_batch, raw, strict=True):
try:
out.append(_strip_to_json(text))
continue
except (ValueError, json.JSONDecodeError):
pass
retry = list(messages) + [
{"role": "assistant", "content": text},
{
"role": "user",
"content": (
"Your previous reply was not valid JSON. "
"Reply with strictly valid JSON, no prose, no fences."
),
},
]
retry_text = self.generate_text([retry], max_tok, temp)[0]
out.append(_strip_to_json(retry_text))
return out
def make_vlm_client(config: VlmConfig) -> VlmClient:
"""Build the shared VLM client per the configured backend.
For ``stub``, callers should construct :class:`StubVlmClient` directly with
a responder callable. ``stub`` here is rejected to make accidental misuse
obvious.
"""
if config.backend == "stub":
raise ValueError(
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
)
if config.backend == "vllm":
return _make_vllm_client(config)
if config.backend == "transformers":
return _make_transformers_client(config)
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
def _make_vllm_client(config: VlmConfig) -> VlmClient:
try:
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
) from exc
llm = LLM(model=config.model_id)
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
params = SamplingParams(
max_tokens=max_tok,
temperature=temp,
guided_decoding={"json": {}} if config.json_mode else None,
)
prompts = [_messages_to_prompt(m) for m in batch]
outputs = llm.generate(prompts, params)
return [o.outputs[0].text for o in outputs]
return _GenericTextClient(_gen, config)
def _make_transformers_client(config: VlmConfig) -> VlmClient:
try:
import torch # type: ignore[import-not-found]
from transformers import AutoModelForVision2Seq, AutoProcessor # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
processor = AutoProcessor.from_pretrained(config.model_id)
model = AutoModelForVision2Seq.from_pretrained(config.model_id, torch_dtype="auto")
model.eval()
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
outs: list[str] = []
for messages in batch:
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(text=[text], return_tensors="pt").to(model.device)
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_tok,
temperature=temp,
do_sample=temp > 0.0,
)
decoded = processor.batch_decode(
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
)[0]
outs.append(decoded)
return outs
return _GenericTextClient(_gen, config)
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
"""Pass-through hook used by the vllm backend.
vllm exposes its own multimodal entry points that vary by version; for the
base flow we simply forward the raw message list and let the caller's
custom backend handle templating. Real deployments override this.
"""
return list(messages)
@@ -0,0 +1,339 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.
For every episode the writer:
1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. adds a top-level ``tools`` column containing the JSON schema for ``say``,
8. writes the parquet shard back in place.
Invariants enforced here (and re-checked by the validator):
- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
(timestamps come straight from the source parquet — never recomputed);
- every row passes ``column_for_style(style)``.
"""
from __future__ import annotations
import json
import logging
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
SAY_TOOL_SCHEMA: dict[str, Any] = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
# events are bucketed per-frame, but within a frame we still want determinism
return (row.get("style") or "", row.get("role") or "")
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
if style not in PERSISTENT_STYLES:
raise ValueError(
f"persistent slice contains row with non-persistent style {style!r}; "
"row would be misrouted under column_for_style()"
)
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"timestamp": float(row["timestamp"]),
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
style = row.get("style")
if style is not None and style not in EVENT_ONLY_STYLES:
raise ValueError(
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
)
if column_for_style(style) != LANGUAGE_EVENTS:
raise ValueError(f"event row with style {style!r} would not route to language_events")
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
def _normalize_tool_calls(value: Any) -> list[Any] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
return list(value)
def _validate_atom_invariants(row: dict[str, Any]) -> None:
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
has_content = row.get("content") is not None
has_tools = row.get("tool_calls") is not None
if not (has_content or has_tools):
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
if row.get("style") is None and not has_tools:
raise ValueError(f"style=None requires tool_calls: {row!r}")
def _validate_speech_atom(row: dict[str, Any]) -> None:
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
if row.get("style") is not None:
return # not a speech atom
if row.get("role") != "assistant":
raise ValueError(f"speech atom must have role=assistant: {row!r}")
if row.get("content") is not None:
raise ValueError(f"speech atom must have content=null: {row!r}")
tool_calls = row.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
first = tool_calls[0]
if not isinstance(first, dict):
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
if first.get("type") != "function":
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
fn = first.get("function") or {}
if fn.get("name") != "say":
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
args = fn.get("arguments") or {}
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
@dataclass
class LanguageColumnsWriter:
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
drop_existing_subtask_index: bool = True
def write_all(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> list[Path]:
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
for record in records:
episodes_by_path[record.data_path].append(record)
written: list[Path] = []
for path, eps in episodes_by_path.items():
self._rewrite_one(path, eps, staging_dir, root)
written.append(path)
return written
def _rewrite_one(
self,
path: Path,
episodes: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> None:
table = pq.read_table(path)
n_rows = table.num_rows
# Ensure we cover every episode in the file. Episodes that don't have
# staging artifacts are passed through with empty annotation lists —
# this keeps the writer idempotent and safe for partial reruns.
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
for record in episodes:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged_per_ep[record.episode_index] = staging.read_all()
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
for ep_index, ep_staged in staged_per_ep.items():
persistent_rows: list[dict[str, Any]] = []
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
for _module_name, rows in ep_staged.items():
for row in rows:
style = row.get("style")
if column_for_style(style) == LANGUAGE_PERSISTENT:
persistent_rows.append(row)
else:
event_rows.append(row)
persistent_rows.sort(key=_row_persistent_sort_key)
normalized_persistent = []
for r in persistent_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
normalized_persistent.append(_normalize_persistent_row(r))
persistent_by_ep[ep_index] = normalized_persistent
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
for r in event_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
ts = float(r["timestamp"])
buckets[ts].append(_normalize_event_row(r))
for ts in list(buckets.keys()):
buckets[ts].sort(key=_row_event_sort_key)
events_by_ep_ts[ep_index] = buckets
episode_col = (
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
)
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
if episode_col is None or ts_col is None:
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
per_row_persistent: list[list[dict[str, Any]]] = []
per_row_events: list[list[dict[str, Any]]] = []
for i in range(n_rows):
ep = episode_col[i]
ts = float(ts_col[i])
per_row_persistent.append(persistent_by_ep.get(ep, []))
buckets = events_by_ep_ts.get(ep, {})
per_row_events.append(buckets.get(ts, []))
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
pq.write_table(new_table, path)
def _materialize_table(
self,
table: pa.Table,
persistent: list[list[dict[str, Any]]],
events: list[list[dict[str, Any]]],
*,
drop_old: bool,
) -> pa.Table:
cols = []
names = []
for name in table.column_names:
if drop_old and name == "subtask_index":
continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"):
continue # we'll re-add canonical versions
cols.append(table.column(name))
names.append(name)
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — see PR 1's
# `tests/datasets/test_language.py` which exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
# Dataset-level tools column. Store the JSON schema as a string per
# row (broadcast-identical, parquet dictionary-encodes it) — string
# storage avoids requiring pa.json_() on every consumer.
tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True)
tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string())
cols.append(tools_arr)
names.append("tools")
return pa.Table.from_arrays(cols, names=names)
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
"""Build a canonical speech tool-call atom for the events column."""
return {
"role": "assistant",
"content": None,
"style": None,
"timestamp": float(timestamp),
"tool_calls": [
{
"type": "function",
"function": {
"name": "say",
"arguments": {"text": text},
},
}
],
}
def normalize_rows_for_writer(
rows: Iterable[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Helper used by tests/validators to partition a flat row list into
(persistent_rows, event_rows) using ``column_for_style``.
"""
persistent: list[dict[str, Any]] = []
events: list[dict[str, Any]] = []
for row in rows:
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
return persistent, events
+100
View File
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-annotate`` — populate ``language_persistent`` and
``language_events`` columns on a LeRobot dataset.
Annotations live directly in ``data/chunk-*/file-*.parquet``: there is no
flavor namespace and no sidecar tree. Multiple revisions of the same dataset
mean multiple dataset copies.
Example:
uv run lerobot-annotate \\
--root=/path/to/dataset \\
--vlm.backend=transformers \\
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
"""
import logging
from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
from lerobot.configs import parser
logger = logging.getLogger(__name__)
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
if cfg.root is not None:
return Path(cfg.root)
if cfg.repo_id is not None:
from huggingface_hub import snapshot_download
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
raise ValueError("Either --root or --repo_id must be provided.")
@parser.wrap()
def annotate(cfg: AnnotationPipelineConfig) -> None:
"""Run the steerable annotation pipeline against a dataset."""
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
root = _resolve_root(cfg)
logger.info("annotate: root=%s", root)
vlm = make_vlm_client(cfg.vlm)
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1)
module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed)
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed)
writer = LanguageColumnsWriter()
validator = StagingValidator()
executor = Executor(
config=cfg,
module_1=module_1,
module_2=module_2,
module_3=module_3,
writer=writer,
validator=validator,
)
summary = executor.run(root)
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
for phase in summary.phases:
logger.info(
"annotate: phase=%s processed=%d skipped=%d",
phase.name,
phase.episodes_processed,
phase.episodes_skipped,
)
if summary.validation_report.warnings:
for w in summary.validation_report.warnings:
logger.warning(w)
def main() -> None:
annotate()
if __name__ == "__main__":
main()
View File
+58
View File
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers shared across annotation-pipeline tests."""
from __future__ import annotations
import json
from typing import Any
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
def make_canned_responder(
responses_by_marker: dict[str, Any],
default: Any = None,
) -> StubVlmClient:
"""Return a stub that picks a response by inspecting the user prompt.
For each call the responder examines the last user-message text and
returns the response keyed by the first marker substring it contains.
Falls back to ``default`` if no marker matches.
"""
def responder(messages: list[dict[str, Any]]) -> Any:
last_user_text = ""
for message in messages:
if message.get("role") != "user":
continue
content = message.get("content")
if isinstance(content, str):
last_user_text = content
elif isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
last_user_text = block.get("text", "")
for marker, response in responses_by_marker.items():
if marker in last_user_text:
return response
return default
return StubVlmClient(responder=responder)
def encode_vqa_answer(payload: dict[str, Any]) -> str:
return json.dumps(payload, sort_keys=True)
+112
View File
@@ -0,0 +1,112 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared fixtures for annotation-pipeline tests.
Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests
can exercise real parquet reads and writes without needing a checked-in
LFS dataset.
"""
from __future__ import annotations
import json
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
def _make_episode_table(
episode_index: int,
num_frames: int,
*,
fps: int = 10,
task_index: int = 0,
) -> pa.Table:
timestamps = [round(i / fps, 6) for i in range(num_frames)]
frame_indices = list(range(num_frames))
return pa.Table.from_pydict(
{
"episode_index": [episode_index] * num_frames,
"frame_index": frame_indices,
"timestamp": timestamps,
"task_index": [task_index] * num_frames,
"subtask_index": [0] * num_frames, # legacy column the writer must drop
}
)
def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path:
"""Create a fixture dataset under ``root``.
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet``
so the writer's per-shard rewrite path is exercised.
"""
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
tasks = {}
for episode_index, num_frames, task_text in episode_specs:
task_index = len(tasks)
if task_text not in tasks.values():
tasks[task_index] = task_text
else:
task_index = next(k for k, v in tasks.items() if v == task_text)
table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index)
path = data_dir / f"file-{episode_index:03d}.parquet"
pq.write_table(table, path)
meta_dir = root / "meta"
meta_dir.mkdir(parents=True, exist_ok=True)
tasks_table = pa.Table.from_pydict(
{
"task_index": list(tasks.keys()),
"task": list(tasks.values()),
}
)
pq.write_table(tasks_table, meta_dir / "tasks.parquet")
info = {
"codebase_version": "v3.1",
"fps": fps,
"total_episodes": len(episode_specs),
}
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
return root
@pytest.fixture
def fixture_dataset_root(tmp_path: Path) -> Path:
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
return _build_dataset(
tmp_path / "ds",
episode_specs=[
(0, 12, "Could you tidy the kitchen please?"),
(1, 12, "Please clean up the kitchen"),
],
fps=10,
)
@pytest.fixture
def single_episode_root(tmp_path: Path) -> Path:
return _build_dataset(
tmp_path / "ds_one",
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
fps=10,
)
+124
View File
@@ -0,0 +1,124 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Opt-in E2E smoke run for ``make annotation-e2e``.
Builds the same fixture used by the pytest suite, runs the full
annotation pipeline against it with a stub VLM, and prints a short report.
This is intentionally not a pytest test — it exercises the CLI plumbing
without depending on conftest.py fixtures.
"""
from __future__ import annotations
import json
import sys
import tempfile
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
def _build_dataset(root: Path) -> Path:
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
n = 30
timestamps = [round(i / 10, 6) for i in range(n)]
table = pa.Table.from_pydict(
{
"episode_index": [0] * n,
"frame_index": list(range(n)),
"timestamp": timestamps,
"task_index": [0] * n,
"subtask_index": [0] * n,
}
)
pq.write_table(table, data_dir / "file-000.parquet")
meta = root / "meta"
meta.mkdir(parents=True, exist_ok=True)
pq.write_table(
pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]}),
meta / "tasks.parquet",
)
(meta / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10}))
return root
def _stub_responder(messages):
text = ""
for m in messages:
if m.get("role") == "user":
content = m.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text = block.get("text", "")
elif isinstance(content, str):
text = content
if "Decompose the demonstration" in text:
return {
"subtasks": [
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
]
}
if "concise hierarchical PLAN" in text:
return {"plan": "1. grasp\n2. pour\n3. place"}
if "Update the memory" in text:
return {"memory": "poured once"}
if "acknowledgement the robot" in text:
return {"text": "Sure."}
if "ONE realistic interruption" in text:
return {"interjection": "use less water", "speech": "Using less water."}
if "frame-grounded visual question" in text:
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
return None
def main() -> int:
with tempfile.TemporaryDirectory() as tmp:
root = _build_dataset(Path(tmp) / "ds")
vlm = StubVlmClient(responder=_stub_responder)
cfg = AnnotationPipelineConfig()
executor = Executor(
config=cfg,
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1),
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed),
module_3=GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
summary = executor.run(root)
print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}")
print(f"validation: {summary.validation_report.summary()}")
print(f"shards rewritten: {len(summary.written_paths)}")
return 0
if __name__ == "__main__":
sys.exit(main())
+166
View File
@@ -0,0 +1,166 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 1/2/3 unit tests with stubbed VLMs."""
from __future__ import annotations
import json
from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import (
Module1Config,
Module2Config,
Module3Config,
)
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
from ._helpers import make_canned_responder
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
vlm = make_canned_responder(
{
"Decompose the demonstration": {
"subtasks": [
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
]
},
"write a concise hierarchical PLAN": {"plan": "1. grasp\n2. wipe\n3. place"},
"Update the memory": {"memory": "wiped the counter once"},
},
)
module = PlanSubtasksMemoryModule(vlm=vlm, config=Module1Config())
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_1")
styles = {r["style"] for r in rows}
assert {"subtask", "plan", "memory"}.issubset(styles)
# subtask timestamps must be exact frame timestamps
frame_set = set(record.frame_timestamps)
for row in rows:
assert row["timestamp"] in frame_set
# exactly one plan row at t0
plan_rows = [r for r in rows if r["style"] == "plan"]
assert len(plan_rows) == 1
assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
vlm = make_canned_responder(
{"acknowledgement the robot": {"text": "Sure, on it."}},
)
module = InterjectionsAndSpeechModule(
vlm=vlm,
config=Module2Config(max_interjections_per_episode=0),
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_2")
assert len(rows) == 1
only = rows[0]
assert only["role"] == "assistant"
assert only["style"] is None
assert only["content"] is None
assert only["timestamp"] == record.frame_timestamps[0]
assert only["tool_calls"][0]["function"]["name"] == "say"
def test_module2_mid_episode_emits_paired_interjection_and_speech(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
vlm = make_canned_responder(
{
"acknowledgement the robot": {"text": "OK."},
"ONE realistic interruption": {
"interjection": "actually skip the dishes",
"speech": "Skipping the dishes.",
},
},
)
module = InterjectionsAndSpeechModule(
vlm=vlm,
config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2),
seed=7,
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_2")
interjections = [r for r in rows if r["style"] == "interjection"]
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
assert len(interjections) == 1
assert len(speeches) >= 2 # initial t=0 + one paired with the interjection
inter_t = interjections[0]["timestamp"]
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None:
payload = {
"question": "How many cups?",
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
}
vlm = make_canned_responder({"frame-grounded visual question": payload})
module = GeneralVqaModule(
vlm=vlm,
config=Module3Config(vqa_emission_hz=1.0, K=3),
seed=1,
)
record = next(iter_episodes(single_episode_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_3")
user_ts = [r["timestamp"] for r in rows if r["role"] == "user" and r["style"] == "vqa"]
assistant_ts = [r["timestamp"] for r in rows if r["role"] == "assistant" and r["style"] == "vqa"]
# at most one user (vqa) per frame; same for assistant
assert len(user_ts) == len(set(user_ts))
assert len(assistant_ts) == len(set(assistant_ts))
# every emitted timestamp must be an exact source frame timestamp
frame_set = set(record.frame_timestamps)
for ts in user_ts + assistant_ts:
assert ts in frame_set
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
payload = {
"question": "Where is the cup?",
"answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
}
vlm = make_canned_responder({"frame-grounded visual question": payload})
module = GeneralVqaModule(
vlm=vlm,
config=Module3Config(vqa_emission_hz=1.0, K=2),
seed=2,
)
record = next(iter_episodes(single_episode_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_3")
for row in rows:
if row["role"] == "assistant" and row["style"] == "vqa":
decoded = json.loads(row["content"])
assert "detections" in decoded
@@ -0,0 +1,135 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering."""
from __future__ import annotations
import json
from pathlib import Path
import pyarrow.parquet as pq
from lerobot.annotations.steerable_pipeline.config import (
AnnotationPipelineConfig,
Module1Config,
Module2Config,
Module3Config,
)
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language_render import render_sample
from ._helpers import make_canned_responder
_RECIPE_PATH = (
Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes" / "pi05_hirobot.yaml"
)
def _build_executor() -> Executor:
vlm = make_canned_responder(
{
"Decompose the demonstration": {
"subtasks": [
{"text": "grasp the bottle", "start": 0.0, "end": 0.5},
{"text": "pour into the cup", "start": 0.5, "end": 1.0},
{"text": "place the bottle down", "start": 1.0, "end": 1.5},
]
},
"write a concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. place"},
"Update the memory": {"memory": "poured once"},
"acknowledgement the robot": {"text": "Sure."},
"ONE realistic interruption": {
"interjection": "use less water",
"speech": "Using less water.",
},
"frame-grounded visual question": {
"question": "How many cups?",
"answer": {"label": "cup", "count": 1},
},
},
)
config = AnnotationPipelineConfig(
module_1=Module1Config(),
module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5),
module_3=Module3Config(vqa_emission_hz=1.0, K=2),
)
return Executor(
config=config,
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1),
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed),
module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
single_episode_root: Path,
) -> None:
executor = _build_executor()
summary = executor.run(single_episode_root)
# validator may emit warnings but no errors for the synthetic fixture
assert summary.validation_report.ok, summary.validation_report.summary()
table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
persistent_lists = table.column("language_persistent").to_pylist()
events_lists = table.column("language_events").to_pylist()
timestamps = table.column("timestamp").to_pylist()
recipe = TrainingRecipe.from_yaml(_RECIPE_PATH) if hasattr(TrainingRecipe, "from_yaml") else None
if recipe is None:
# PR 1 may not expose from_yaml; load via PyYAML and TrainingRecipe(**...)
import yaml
loaded = yaml.safe_load(_RECIPE_PATH.read_text(encoding="utf-8"))
recipe = TrainingRecipe(**loaded)
rendered_any = False
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
result = render_sample(
recipe=recipe,
persistent=persistent,
events=events,
t=float(ts),
sample_idx=0,
dataset_ctx={"task": "Pour water from the bottle into the cup."},
)
if result is None:
continue
if result["messages"]:
rendered_any = True
assert result["target_message_indices"]
break
assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"
# Sanity: speech atom appears in events column intact
flat_events = [r for ev in events_lists for r in ev]
speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
assert speech_rows
say = speech_rows[0]["tool_calls"][0]
assert say["function"]["name"] == "say"
assert isinstance(say["function"]["arguments"]["text"], str)
# Tools column carries the say schema
tools = json.loads(table.column("tools").to_pylist()[0])
assert tools and tools[0]["function"]["name"] == "say"
+125
View File
@@ -0,0 +1,125 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validator behavior tests."""
from __future__ import annotations
import json
from pathlib import Path
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.writer import speech_atom
def _validate(root: Path, staging_dir: Path):
records = list(iter_episodes(root))
return StagingValidator().validate(records, staging_dir)
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_3",
[
{
"role": "assistant",
"content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
"style": "vqa",
"timestamp": 9.999, # not on any 10 fps frame
"tool_calls": None,
}
],
)
report = _validate(fixture_dataset_root, staging_dir)
assert not report.ok
assert any("does not match any source frame timestamp" in e for e in report.errors)
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_2",
[
speech_atom(0.0, "Got it."),
# interjection at 0.3s with NO paired speech
{
"role": "user",
"content": "skip it",
"style": "interjection",
"timestamp": 0.3,
"tool_calls": None,
},
],
)
report = _validate(fixture_dataset_root, staging_dir)
assert not report.ok
assert any("paired speech" in e for e in report.errors)
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_1",
[
{
"role": "assistant",
"content": "1. do x",
"style": "plan",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "do x",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
],
)
EpisodeStaging(staging_dir, 0).write(
"module_2",
[
speech_atom(0.0, "Got it."),
speech_atom(0.4, "Replanning."),
{
"role": "user",
"content": "replan",
"style": "interjection",
"timestamp": 0.4,
"tool_calls": None,
},
],
)
report = _validate(fixture_dataset_root, staging_dir)
# missing co-timestamped plan refresh at 0.4s → error
assert not report.ok
assert any("co-timestamped plan update" in e for e in report.errors)
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_1",
[
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
],
)
report = _validate(fixture_dataset_root, staging_dir)
assert not report.ok
assert any("module_1 emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
+283
View File
@@ -0,0 +1,283 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Writer correctness tests."""
from __future__ import annotations
import json
from pathlib import Path
import pyarrow.parquet as pq
import pytest
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
from lerobot.annotations.steerable_pipeline.writer import (
LanguageColumnsWriter,
speech_atom,
)
def _stage_episode(
staging_dir: Path,
episode_index: int,
*,
module_1: list[dict] | None = None,
module_2: list[dict] | None = None,
module_3: list[dict] | None = None,
) -> None:
staging = EpisodeStaging(staging_dir, episode_index)
if module_1 is not None:
staging.write("module_1", module_1)
if module_2 is not None:
staging.write("module_2", module_2)
if module_3 is not None:
staging.write("module_3", module_3)
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
"""Every frame in an episode has a byte-identical persistent list."""
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
module_1=[
{
"role": "assistant",
"content": "grasp the sponge",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "1. wipe\n2. dry",
"style": "plan",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "wiped the counter",
"style": "memory",
"timestamp": 0.5,
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
persistent = table.column("language_persistent").to_pylist()
first = persistent[0]
assert first # non-empty
for row in persistent:
assert row == first, "persistent slice must be byte-identical across all frames"
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
module_2=[
speech_atom(0.0, "Got it."),
{
"role": "user",
"content": "skip the dishes",
"style": "interjection",
"timestamp": 0.5,
"tool_calls": None,
},
speech_atom(0.5, "Skipping the dishes."),
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
timestamps = table.column("timestamp").to_pylist()
events = table.column("language_events").to_pylist()
for ts, ev in zip(timestamps, events, strict=True):
if abs(ts - 0.0) < 1e-9:
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
elif abs(ts - 0.5) < 1e-9:
assert any(r.get("style") == "interjection" for r in ev), ev
assert any(r.get("style") is None for r in ev), ev
else:
assert ev == []
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
module_1=[
{
"role": "assistant",
"content": "do X",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "1. do X",
"style": "plan",
"timestamp": 0.0,
"tool_calls": None,
},
{
"role": "assistant",
"content": "did X",
"style": "memory",
"timestamp": 0.3,
"tool_calls": None,
},
],
module_2=[
speech_atom(0.0, "OK"),
{
"role": "user",
"content": "wait",
"style": "interjection",
"timestamp": 0.2,
"tool_calls": None,
},
speech_atom(0.2, "Waiting"),
],
module_3=[
{
"role": "user",
"content": "where is the cup?",
"style": "vqa",
"timestamp": 0.4,
"tool_calls": None,
},
{
"role": "assistant",
"content": json.dumps(
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
sort_keys=True,
),
"style": "vqa",
"timestamp": 0.4,
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
persistent = table.column("language_persistent").to_pylist()[0]
persistent_styles = {r["style"] for r in persistent}
assert persistent_styles == {"subtask", "plan", "memory"}
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
event_styles = {r.get("style") for r in all_events}
assert event_styles == {None, "interjection", "vqa"}
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
module_1=[
{
"role": "assistant",
"content": "do X",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
},
],
)
records = list(iter_episodes(fixture_dataset_root))
writer = LanguageColumnsWriter()
writer.write_all(records, staging_dir, fixture_dataset_root)
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
table_a = pq.read_table(path)
assert "subtask_index" not in table_a.column_names
assert "language_persistent" in table_a.column_names
assert "language_events" in table_a.column_names
assert "tools" in table_a.column_names
# second pass — must produce identical bytes for the language columns
records_again = list(iter_episodes(fixture_dataset_root))
writer.write_all(records_again, staging_dir, fixture_dataset_root)
table_b = pq.read_table(path)
assert (
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
)
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
"""``_normalize_persistent_row`` must reject any non-persistent style."""
from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row
with pytest.raises(ValueError, match="non-persistent style"):
_normalize_persistent_row(
{"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
)
def test_writer_normalize_rejects_misrouted_event_style() -> None:
"""``_normalize_event_row`` must reject any persistent style."""
from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row
with pytest.raises(ValueError):
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
_stage_episode(
staging_dir,
0,
module_1=[
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
],
)
records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
tools = table.column("tools").to_pylist()
assert tools, "tools column missing"
decoded = json.loads(tools[0])
assert isinstance(decoded, list)
assert len(decoded) == 1
assert decoded[0]["function"]["name"] == "say"
params = decoded[0]["function"]["parameters"]
assert params["properties"]["text"]["type"] == "string"
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
assert atom["style"] is None
assert atom["content"] is None
assert atom["timestamp"] == 2.5
assert isinstance(atom["tool_calls"], list)
call = atom["tool_calls"][0]
assert call["type"] == "function"
assert call["function"]["name"] == "say"
assert call["function"]["arguments"]["text"] == "I'm cleaning up!"