review: address CarolinePascal feedback

- name the three modules everywhere (plan / interjections / vqa) instead
  of module_1/2/3 — config classes, config fields, executor params,
  staging keys and phase names now carry the module name
- rename examples/annotation -> examples/annotations; add the Apache
  header to run_hf_job.py
- drop the unused GeneralVqaModule._generate_one
- remove "PR 1" references from comments/docstrings
- frames.py: rely on the always-defined LeRobotDatasetMetadata.camera_keys
- executor.py: read/write meta/info.json via load_info / write_info
- reader.py: load meta/tasks.parquet via io_utils.load_tasks
- make --push_to_hub a bool; push the annotated dataset back to --repo_id
- move the on-disk test dataset builder into tests/fixtures
  (build_annotation_dataset); run_e2e_smoke reuses it
- clarify in the docs that the vqa module grounds each pair on a single
  frame (K = per-tick anchor count)
- hoist stdlib dynamic imports to module scope

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn Kooijmans
2026-05-18 12:03:25 +02:00
parent 965d42825f
commit fd18beb3a1
23 changed files with 383 additions and 412 deletions
+35 -29
@@ -11,14 +11,14 @@ Three modules write into a per-episode staging tree, then a single writer
rewrites the data shards in place:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | -------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 |
| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 |
| `memory` (MEM-style compression) | `language_persistent` | Module 1 |
| `task_aug` (rephrasings of canonical task) | `language_persistent` | Module 1 |
| `interjection` | `language_events` | Module 2 |
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
| `vqa` (user / assistant pair) | `language_events` | Module 3 |
| ------------------------------------------- | --------------------- | -------------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
| `interjection` | `language_events` | `interjections`|
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
The writer does **not** add a `tools` column to the parquet — the tool
catalog lives at `meta/info.json["tools"]` instead (see
@@ -45,20 +45,24 @@ uv run lerobot-annotate \
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
```
The pipeline attaches actual camera footage to every Module 1/2/3 prompt
by default, decoded from the dataset's first `observation.images.*`
stream. Override with `--vlm.camera_key=observation.images.<name>` to
pin a specific viewpoint. Datasets with no video tracks fall back to
text-only prompts automatically.
The pipeline attaches actual camera footage to every `plan` /
`interjections` / `vqa` prompt by default, decoded from the dataset's
first `observation.images.*` stream. Override with
`--vlm.camera_key=observation.images.<name>` to pin a specific
viewpoint. Datasets with no video tracks fall back to text-only prompts
automatically.
**Module 1 sees the whole episode as one video block.** Subtask
**The `plan` module sees the whole episode as one video block.** Subtask
decomposition gets a `{"type":"video", "video":[<frames>]}` block
covering the entire demonstration; Qwen-VL pools temporally on its own
and decides where to cut. There is no keyframe stride or count knob —
`--module_1.max_video_frames` (default 128) only caps the frames packed
into the video block as a model-capacity bound. Module 2 attaches a
short window of frames around the interjection timestamp; Module 3
attaches the exact emission frame to each VQA pair.
`--plan.max_video_frames` (default 128) only caps the frames packed
into the video block as a model-capacity bound. The `interjections`
module attaches a short window of frames straddling the interjection
timestamp. The `vqa` module grounds each VQA pair on a single frame —
its `--vqa.K` knob sets how many consecutive frames each emission tick
anchors, and every anchored frame gets its own VQA pair on that one
frame (there is no per-pair frame window).
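The tick / `K` relationship described above can be sketched as follows. This is a minimal illustration, not the pipeline's own anchor selection: `anchor_indices` is a hypothetical helper, and snapping each tick to the nearest frame, starting ticks at the first timestamp, and de-duplicating overlapping anchors are all assumptions made for the example.

```python
def anchor_indices(frame_timestamps: list[float], hz: float, k: int) -> list[int]:
    """Return the frame indices anchored by emission ticks.

    Hypothetical helper for illustration only. A tick fires every ``1/hz``
    seconds; each tick anchors ``k`` consecutive frames starting at the
    frame nearest to the tick time (nearest-frame snapping is an assumption).
    """
    if not frame_timestamps or hz <= 0 or k <= 0:
        return []
    period = 1.0 / hz
    start, end = frame_timestamps[0], frame_timestamps[-1]
    anchors: list[int] = []
    seen: set[int] = set()
    tick = 0
    while start + tick * period <= end + 1e-9:
        t = start + tick * period
        # Snap the tick to the closest available frame.
        first = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
        # Anchor k consecutive frames, clipped at the end of the episode.
        for idx in range(first, min(first + k, len(frame_timestamps))):
            if idx not in seen:
                seen.add(idx)
                anchors.append(idx)
        tick += 1
    return anchors


# A 3-second episode recorded at 10 fps, ticking at 1 Hz.
timestamps = [i / 10 for i in range(30)]
print(anchor_indices(timestamps, hz=1.0, k=1))  # one anchored frame per tick
print(anchor_indices(timestamps, hz=1.0, k=2))  # two consecutive frames per tick
```

Under these assumptions, each index returned corresponds to one single-frame VQA pair per camera, so the pair count scales with `K` while the frames attached to any one pair stay at one.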
## Running on Hugging Face Jobs
@@ -67,15 +71,16 @@ Distributed annotation is delegated to
ships a launcher script you copy and edit for your dataset:
```bash
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
```
[`examples/annotation/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotation/run_hf_job.py)
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
spawns one `h200x2` job that:
1. installs the branch under test plus the annotation extras,
2. boots two vllm servers (one per GPU) for the chosen model,
3. runs Modules 1 / 2 / 3 across the dataset via `lerobot-annotate`,
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
via `lerobot-annotate`,
4. uploads the annotated dataset back to `--repo_id` (when `--push_to_hub=true`).
To target a different dataset, model, or hub repo, edit the `CMD` block
@@ -115,7 +120,8 @@ Each module writes its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
prompt iteration cheap — re-running one module overwrites only its own
JSONL file before the writer composes the final parquet. Modules can be
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
/ `--vqa.enabled`) to
test them in isolation.
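The staging layout above can be sketched in a few lines. This is an illustrative sketch only: `staging_path`, `write_rows`, and `read_rows` are hypothetical helpers that mirror the documented path pattern, not the pipeline's `EpisodeStaging` API.

```python
import json
import tempfile
from pathlib import Path


def staging_path(root: Path, episode_index: int, module: str) -> Path:
    """Build ``<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl``."""
    return root / ".annotate_staging" / f"episode_{episode_index:06d}" / f"{module}.jsonl"


def write_rows(path: Path, rows: list[dict]) -> None:
    """Overwrite one module's JSONL file (hypothetical helper)."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def read_rows(path: Path) -> list[dict]:
    """Read one module's JSONL file; a missing file reads as no rows."""
    if not path.exists():
        return []
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    write_rows(staging_path(root, 3, "plan"), [{"style": "plan"}])
    write_rows(staging_path(root, 3, "vqa"), [{"style": "vqa", "attempt": 1}])
    # Re-running only the vqa module replaces vqa.jsonl and leaves plan.jsonl alone.
    write_rows(staging_path(root, 3, "vqa"), [{"style": "vqa", "attempt": 2}])
    print(read_rows(staging_path(root, 3, "plan")))
    print(read_rows(staging_path(root, 3, "vqa")))
```

Keeping one file per module per episode is what makes a single-module re-run cheap: the other modules' staged rows are never rewritten.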
## Validation/report checks before final write
@@ -134,18 +140,18 @@ Errors abort the writer (`--skip_validation=true` overrides for debugging).
## Paper inspirations per module
- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
atom granularity ("pick up one piece of lettuce", "place bowl to box");
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
what" detail.
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
compression directive: keep only minimal relevant information; functional
outcomes preserved, specific attributes dropped.
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
arguments:{text:...}}}]`).
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
multi-abstraction grounding. Pi0.7 also grounds answers across
@@ -157,9 +163,9 @@ references rather than rewriting from scratch.
## Compute and list-size estimates
Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
O(`max_interjections_per_episode`) Module 2 calls, and
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
O(`max_interjections_per_episode`) `interjections`-module calls, and
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
KB at most (parquet dictionary-encodes one entry per episode);
@@ -1,18 +1,35 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
Spawns one ``h200x2`` job that:
1. installs this branch of ``lerobot`` plus the annotation extras,
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
3. runs Module 1/2/3 across the dataset (per-camera VQA via PR 3471),
4. uploads the annotated dataset to ``--push_to_hub``.
3. runs the plan / interjections / vqa modules across the dataset,
4. uploads the annotated dataset back to ``--repo_id``.
``--repo_id`` is both the download source and, with ``--push_to_hub=true``,
the upload destination, so the job annotates the dataset in place.
Usage:
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
Adjust ``CMD`` below to point at your own dataset / target hub repo.
Adjust ``CMD`` below to point at your own dataset.
"""
import os
@@ -36,7 +53,9 @@ CMD = (
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=imstevenpmwork/super_poulain_draft "
# The dataset to annotate; also the push destination (annotate in place).
"--repo_id=<your-org>/<your-dataset> "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
"--vlm.parallel_servers=2 "
@@ -50,11 +69,10 @@ CMD = (
"--executor.episode_parallelism=32 "
"--vlm.chat_template_kwargs='{enable_thinking: false}' "
"--vlm.camera_key=observation.images.wrist "
"--module_1.frames_per_second=1.0 "
"--module_1.use_video_url=true "
"--module_1.use_video_url_fps=1.0 "
"--module_3.K=1 --module_3.vqa_emission_hz=0.2 "
"--push_to_hub=pepijn223/super_poulain_qwen36moe-3"
"--plan.frames_per_second=1.0 "
"--plan.use_video_url=true "
"--plan.use_video_url_fps=1.0 "
"--vqa.K=1 --vqa.vqa_emission_hz=0.2"
)
job = run_job(
+1 -1
@@ -205,7 +205,7 @@ peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# on Linux, with a transformers fallback elsewhere; openai is the default
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
# ``transformers serve`` / hosted endpoints). Distributed execution is
# delegated to Hugging Face Jobs (see examples/annotation/run_hf_job.py).
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
@@ -19,9 +19,9 @@
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles
- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech
- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
@@ -22,12 +22,12 @@ from typing import Any
@dataclass
class Module1Config:
"""Module 1: plan + subtasks + memory + task augmentation.
class PlanConfig:
"""``plan`` module: plan + subtasks + memory + task augmentation.
Module 1 attaches the whole episode as one Qwen-VL video block;
``max_video_frames`` only caps the frames packed in (a model-capacity
bound, not an annotation-logic knob).
The ``plan`` module attaches the whole episode as one Qwen-VL video
block; ``max_video_frames`` only caps the frames packed in (a
model-capacity bound, not an annotation-logic knob).
"""
enabled: bool = True
@@ -39,8 +39,8 @@ class Module1Config:
# When to derive the task from the video instead of using
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
# missing canonical task), or ``always``. The derived task replaces the
# canonical one for every Module-1 prompt; ``meta/tasks.parquet`` is
# never modified.
# canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
# is never modified.
derive_task_from_video: str = "if_short"
derive_task_min_words: int = 3
@@ -51,21 +51,22 @@ class Module1Config:
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
# When True (and backend supports it, e.g. ``openai``), Module 1 sends a
# ``video_url`` block pointing at a per-episode mp4 subclip and lets the
# server sample frames at ``use_video_url_fps``.
# When True (and backend supports it, e.g. ``openai``), the ``plan``
# module sends a ``video_url`` block pointing at a per-episode mp4
# subclip and lets the server sample frames at ``use_video_url_fps``.
use_video_url: bool = False
use_video_url_fps: float = 1.0
@dataclass
class Module2Config:
"""Module 2: interjections + paired speech."""
class InterjectionsConfig:
"""``interjections`` module: interjections + paired speech."""
enabled: bool = True
# Each interjection emits a paired ``(interjection, speech)`` event row
# and triggers a ``plan`` refresh at the same timestamp via Module 1.
# and triggers a ``plan`` refresh at the same timestamp via the
# ``plan`` module.
max_interjections_per_episode: int = 3
interjection_min_t: float = 2.0
@@ -77,8 +78,8 @@ class Module2Config:
@dataclass
class Module3Config:
"""Module 3: general VQA."""
class VqaConfig:
"""``vqa`` module: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
@@ -161,6 +162,8 @@ class AnnotationPipelineConfig:
revisions of the same dataset live in separate copies.
"""
# Hub dataset id. Used as the download source when ``root`` is unset,
# and as the destination repo when ``push_to_hub`` is enabled.
repo_id: str | None = None
root: Path | None = None
@@ -169,9 +172,9 @@ class AnnotationPipelineConfig:
seed: int = 1729
module_1: Module1Config = field(default_factory=Module1Config)
module_2: Module2Config = field(default_factory=Module2Config)
module_3: Module3Config = field(default_factory=Module3Config)
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
@@ -179,8 +182,9 @@ class AnnotationPipelineConfig:
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
# Upload the annotated dataset to the Hugging Face Hub when set.
push_to_hub: str | None = None
# When True, upload the annotated dataset back to ``repo_id`` on the
# Hugging Face Hub. ``repo_id`` must be set for this to take effect.
push_to_hub: bool = False
push_private: bool = False
push_commit_message: str | None = None
@@ -17,19 +17,20 @@
The executor plans **six phases** in the dependency order from the plan:
phase 1: Module 1 (plan + subtasks + memory)
phase 2: Module 2 (interjections + speech)
phase 3: Module 1 plan-update pass — re-runs plan emission at every
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: Module 3 (VQA)
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why Module 1 must be re-entered after Module 2 — to refresh
``plan`` rows at interjection timestamps.
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotation/run_hf_job.py``); the runner inside the job
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
Episode-level concurrency is controlled by
``ExecutorConfig.episode_parallelism``.
@@ -38,6 +39,8 @@ Episode-level concurrency is controlled by
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -71,7 +74,7 @@ class PipelineRunSummary:
@dataclass
class Executor:
"""Run all four phases over a dataset root in-process.
"""Run all six phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
@@ -80,9 +83,9 @@ class Executor:
"""
config: AnnotationPipelineConfig
module_1: Any # PlanSubtasksMemoryModule
module_2: Any # InterjectionsAndSpeechModule
module_3: Any # GeneralVqaModule
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
@@ -99,16 +102,16 @@ class Executor:
phases: list[PhaseResult] = []
# Phase 1: Module 1 (plan + subtasks + memory)
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
# Phase 2: Module 2 (interjections + speech). Module 2 reads
# Module 1's subtask rows from the same staging tree to ground
# the interjection prompt in the correct local subtask.
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
# Phase 3: Module 1 plan-update pass at interjection timestamps.
# Phase 1: ``plan`` module (plan + subtasks + memory)
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
# Phase 2: ``interjections`` module (interjections + speech). It
# reads the ``plan`` module's subtask rows from the same staging
# tree to ground the interjection prompt in the correct local subtask.
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: Module 3 (VQA)
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
@@ -135,50 +138,37 @@ class Executor:
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
cast against the old schema and fail on the extra parquet columns.
"""
import json # noqa: PLC0415
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
info_path = root / "meta" / "info.json"
if not info_path.exists():
return
try:
info = json.loads(info_path.read_text())
info = load_info(root)
except Exception as exc: # noqa: BLE001
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
return
changed = False
features = info.get("features")
if not isinstance(features, dict):
features = {}
merged_features = {**features, **language_feature_info()}
if merged_features != features:
info["features"] = merged_features
merged_features = {**info.features, **language_feature_info()}
if merged_features != info.features:
info.features = merged_features
changed = True
existing = info.get("tools")
if not isinstance(existing, list):
existing = []
existing = info.tools or []
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
merged = list(existing)
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
merged.append(SAY_TOOL_SCHEMA)
if merged != existing:
info["tools"] = merged
info.tools = [*existing, SAY_TOOL_SCHEMA]
changed = True
if changed:
# Atomic replace — info.json is load-bearing for dataset
# metadata, so a crash mid-write would brick the dataset.
tmp_info = info_path.with_suffix(info_path.suffix + ".tmp")
tmp_info.write_text(json.dumps(info, indent=2))
tmp_info.replace(info_path)
write_info(info, root)
print(
"[annotate] meta/info.json: "
f"language_features={list(language_feature_info())}, "
f"tools={[t['function']['name'] for t in merged]}",
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
flush=True,
)
@@ -189,9 +179,6 @@ class Executor:
staging_dir: Path,
module: Any,
) -> PhaseResult:
import time as _time # noqa: PLC0415
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
@@ -201,14 +188,14 @@ class Executor:
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
flush=True,
)
t0 = _time.time()
t0 = time.time()
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = _time.time()
ep_start = time.time()
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
return i, record.episode_index, _time.time() - ep_start
return i, record.episode_index, time.time() - ep_start
processed = 0
if parallelism == 1:
@@ -230,38 +217,39 @@ class Executor:
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = _time.time() - t0
total = time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase( # noqa: PLR0915
self, records: list[EpisodeRecord], staging_dir: Path
) -> PhaseResult:
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
therefore calls back into Module 1 with the interjection timestamps so
Module 1's existing prompt path is reused.
The ``plan`` module owns the prompt; the ``interjections`` module
produced the timestamps. This phase therefore calls back into the
``plan`` module with the interjection timestamps so its existing
prompt path is reused.
"""
if not self.module_1.enabled or not self.module_2.enabled:
if not self.plan.enabled or not self.interjections.enabled:
return PhaseResult(
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
)
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_rows = [
row for row in staging.read("module_2") if row.get("style") == "interjection"
row for row in staging.read("interjections") if row.get("style") == "interjection"
]
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
if interjection_times:
self.module_1.run_plan_updates(record, staging, interjection_times, interjection_texts)
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
processed += 1
# Episodes without any interjections are skipped (no plan refresh
# needed); count them so the summary's processed+skipped == total.
return PhaseResult(
name="module_1_plan_update",
name="plan_update",
episodes_processed=processed,
episodes_skipped=len(records) - processed,
)
@@ -49,7 +49,7 @@ class FrameProvider(Protocol):
Empty list if the camera is unavailable. ``camera_key=None`` falls back
to the provider's default camera so existing single-camera callers
(Module 1, Module 2) keep working unchanged.
(the ``plan`` and ``interjections`` modules) keep working unchanged.
"""
def video_for_episode(
@@ -100,10 +100,11 @@ def null_provider() -> FrameProvider:
class VideoFrameProvider:
"""Decodes frames from the dataset's ``observation.images.*`` streams.
By default the *first* camera key is used for Module 1 (subtask
decomposition) and Module 2 (interjection scenarios) — those prompts care
about *what is happening*, not which angle. Module 3 (VQA) instead
iterates over every camera in :attr:`camera_keys` so each frame's
By default the *first* camera key is used for the ``plan`` module
(subtask decomposition) and the ``interjections`` module (interjection
scenarios) — those prompts care about *what is happening*, not which
angle. The ``vqa`` module instead iterates over every camera in
:attr:`camera_keys` so each frame's
grounded answer (bbox/keypoint/...) is tagged with the camera it was
grounded against.
@@ -112,7 +113,7 @@ class VideoFrameProvider:
``video_for_episode`` to read a non-default stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped Module 2 + Module 1 plan-update calls cheap.
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
"""
root: Path
@@ -122,7 +123,7 @@ class VideoFrameProvider:
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
# Pipeline runs Module 1/2/3 phases under a ThreadPoolExecutor (see
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
# one-shot warn flag against concurrent updates from worker threads.
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
@@ -131,11 +132,10 @@ class VideoFrameProvider:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
# ``camera_keys`` covers both image- and video-stored cameras
# (``video_keys`` is video-only). Some datasets declare cameras with
# ``dtype=image``, which would otherwise look empty here and silently
# disable Module 3 even though the videos are there.
keys = list(getattr(self._meta, "camera_keys", None) or self._meta.video_keys or [])
# ``camera_keys`` covers both image- and video-stored cameras and is
# always defined on the metadata (``[]`` in the worst case), so it is
# the single source we need here.
keys = list(self._meta.camera_keys)
# Last-resort fallback: if metadata didn't surface anything but the
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
# them — the key is by definition known to exist on the dataset.
@@ -275,10 +275,10 @@ class VideoFrameProvider:
try:
return _decode_pyav_direct(video_path, shifted, self.tolerance_s)
except Exception as exc:
# Log loudly the first time decoding fails so silent
# Module-3-no-op (every prompt skipped because frames_at returned
# []) is debuggable from the job log instead of post-hoc parquet
# inspection. Subsequent failures stay quiet.
# Log loudly the first time decoding fails so a silent
# vqa-module no-op (every prompt skipped because frames_at
# returned []) is debuggable from the job log instead of
# post-hoc parquet inspection. Subsequent failures stay quiet.
with self._lock:
already_warned = getattr(self, "_warned_decode_fail", False)
if not already_warned:
@@ -13,10 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 3: general VQA at a timed cadence.
"""``vqa`` module: general VQA at a timed cadence.
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
emission. For datasets with multiple cameras, every emission tick produces
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
consecutive frames, and every anchored frame gets its own VQA pair. Each
pair is grounded on that single anchor frame — there is no per-pair frame
window. For datasets with multiple cameras, every anchored frame produces
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
generated against that camera's frame and stamped with the matching
``camera`` field on the emitted rows. The resolver disambiguates via
@@ -26,7 +28,7 @@ per camera (see ``recipes/pi05_hirobot.yaml``).
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
Question types covered (per the plan's Module 3 table): bbox, keypoint,
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
@@ -35,12 +37,13 @@ retry inside :meth:`VlmClient.generate_json`.
from __future__ import annotations
import json
import logging
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import Module3Config
from ..config import VqaConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
@@ -89,7 +92,7 @@ class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: Module3Config
config: VqaConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@@ -99,7 +102,7 @@ class GeneralVqaModule:
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("module_3", [])
staging.write("vqa", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
@@ -111,17 +114,15 @@ class GeneralVqaModule:
# untagged rows that would fail validation. Surface a loud one-
# time warning so this is never silently a no-op.
if not getattr(self, "_warned_no_camera", False):
import logging # noqa: PLC0415
logging.getLogger(__name__).warning(
"Module 3 (VQA) found no cameras on the frame provider — "
"vqa module found no cameras on the frame provider — "
"every episode will emit zero VQA rows. Check that the "
"dataset declares observation.images.* features in "
"meta/info.json; passing --vlm.camera_key=<key> at the "
"CLI now also seeds the cameras list as a fallback."
)
self._warned_no_camera = True
staging.write("module_3", [])
staging.write("vqa", [])
return
# Build all messages first (one per (frame, camera)), then issue them
@@ -140,13 +141,13 @@ class GeneralVqaModule:
per_call.append((ts, camera, qtype, messages))
if not per_call:
staging.write("module_3", [])
staging.write("vqa", [])
return
results = self.vlm.generate_json([m for _, _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
qa = self._postprocess(result)
if qa is None:
continue
@@ -171,10 +172,10 @@ class GeneralVqaModule:
"tool_calls": None,
}
)
staging.write("module_3", rows)
staging.write("vqa", rows)
def _target_cameras(self) -> list[str]:
"""Return the cameras Module 3 should iterate per emission tick.
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
Defaults to every camera the provider exposes. Datasets with no
cameras (or test/null providers) yield an empty list, which makes
@@ -214,17 +215,6 @@ class GeneralVqaModule:
return None
return question.strip(), answer
def _generate_one(
self,
record: EpisodeRecord,
question_type: str,
frame_timestamp: float,
camera_key: str,
) -> tuple[str, dict[str, Any]] | None:
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
result = self.vlm.generate_json([messages])[0]
return self._postprocess(result)
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
"""Return True if any user content block is a populated image block."""
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
@@ -26,8 +26,8 @@ Two sub-passes:
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
timestamps to refresh the ``plan`` row at the same instant.
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
interjection timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
@@ -37,7 +37,7 @@ from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import Module2Config
from ..config import InterjectionsConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
@@ -51,7 +51,7 @@ class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: Module2Config
config: InterjectionsConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@@ -66,13 +66,13 @@ class InterjectionsAndSpeechModule:
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
# Pull Module 1's subtask spans for this episode so the
# Pull the ``plan`` module's subtask spans for this episode so the
# interjection prompt can ground itself in the actual current
# subtask at each chosen timestamp. Module 1 ran first.
# subtask at each chosen timestamp. The ``plan`` module ran first.
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
subtask_spans = reconstruct_subtask_spans(staging.read("module_1"), episode_end_t=episode_end_t)
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
rows.extend(self._mid_episode_interjections(record, subtask_spans))
staging.write("module_2", rows)
staging.write("interjections", rows)
@staticmethod
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
@@ -22,7 +22,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from ..config import Module1Config
from ..config import PlanConfig
from ..frames import (
FrameProvider,
VideoFrameProvider,
@@ -46,13 +46,13 @@ class PlanSubtasksMemoryModule:
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
Module 2 completes).
the ``interjections`` module completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: Module1Config
config: PlanConfig
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
@@ -61,14 +61,14 @@ class PlanSubtasksMemoryModule:
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
# Resolve the task that drives every other Module-1 prompt. May be
# the canonical ``record.episode_task`` (default), or a fresh
# Resolve the task that drives every other ``plan``-module prompt.
# May be the canonical ``record.episode_task`` (default), or a fresh
# description derived from the video when the canonical task is
# empty / placeholder / forced-off (see Module1Config.derive_task_*).
# empty / placeholder / forced-off (see PlanConfig.derive_task_*).
effective_task = self._resolve_effective_task(record)
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
# PR 1 renderer rotates ``${task}`` deterministically through them
# so the policy sees diverse phrasings during training.
# message renderer rotates ``${task}`` deterministically through
# them so the policy sees diverse phrasings during training.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
if self.config.n_task_rephrasings > 0 and effective_task:
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
@@ -134,7 +134,7 @@ class PlanSubtasksMemoryModule:
}
)
prior_memory = mem_text
staging.write("module_1", rows)
staging.write("plan", rows)
# ------------------------------------------------------------------
# Task derivation + rephrasings
@@ -156,7 +156,7 @@ class PlanSubtasksMemoryModule:
)
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
"""Decide which task string drives Module 1 for this episode.
"""Decide which task string drives the ``plan`` module for this episode.
Returns the user-supplied ``record.episode_task`` unless
``derive_task_from_video`` says otherwise (see config docstring).
@@ -182,7 +182,7 @@ class PlanSubtasksMemoryModule:
return task.lower() in self._PLACEHOLDER_TASKS
# ------------------------------------------------------------------
# VLM call helpers (factored out: every Module-1 prompt below follows
# VLM call helpers (factored out: every ``plan``-module prompt below follows
# the same "build messages → single VLM call → pull a named field"
# shape, only differing in field name + post-processing).
# ------------------------------------------------------------------
@@ -258,7 +258,7 @@ class PlanSubtasksMemoryModule:
(the previous version told the model "an interjection happened"
without telling it what the user said).
"""
existing = staging.read("module_1")
existing = staging.read("plan")
# Pass the episode's last frame timestamp so the final subtask
# span is closed (otherwise its ``end`` equals its ``start``,
# zero duration, and the "current subtask at refresh_t" lookup
@@ -289,7 +289,7 @@ class PlanSubtasksMemoryModule:
"tool_calls": None,
}
)
staging.write("module_1", new_rows)
staging.write("plan", new_rows)
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
if record.row_count == 0 or not record.frame_timestamps:
@@ -38,6 +38,7 @@ from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import load_tasks
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@@ -83,8 +84,9 @@ def reconstruct_subtask_spans(
which is what downstream consumers (memory, interjection boundary
selection) expect.
Used by Module 1 (plan-update pass) and Module 2 (interjection
anchoring), which both need the same span shape.
Used by the ``plan`` module (plan-update pass) and the
``interjections`` module (interjection anchoring), which both need the
same span shape.
"""
sorted_rows = sorted(
(r for r in rows if r.get("style") == "subtask"),
@@ -105,8 +107,9 @@ def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp.
Modules use this when emitting event-style rows so the row's
timestamp matches a real parquet frame (event rows must land on an
exact frame, see PR 1's "exact event matching" rule).
timestamp matches a real parquet frame: event rows must land on an
exact frame, otherwise the per-frame event lookup the writer does
would never match them.
"""
if not frame_timestamps:
return float(t)
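The snapping rule in the docstring above can be sketched as follows. This is only an illustration: the hunk does not show the rest of `snap_to_frame`, so the name `snap_to_frame_sketch` and the nearest-neighbour choice (earlier frame wins an exact tie) are assumptions, not the library's implementation.

```python
from collections.abc import Sequence


def snap_to_frame_sketch(t: float, frame_timestamps: Sequence[float]) -> float:
    """Hypothetical stand-in: return the frame timestamp closest to ``t``."""
    if not frame_timestamps:
        # Nothing to snap to, so pass the value through unchanged.
        return float(t)
    # ``min`` with a distance key picks the nearest frame; on an exact tie
    # it keeps the earlier element of the sequence.
    return float(min(frame_timestamps, key=lambda ft: abs(ft - t)))
```

An event row stamped with the snapped value then compares equal to a real frame timestamp, which is what an exact per-frame lookup needs.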
@@ -115,14 +118,17 @@ def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
def _load_tasks_lookup(root: Path) -> dict[int, str]:
tasks_path = root / DEFAULT_TASKS_PATH
if not tasks_path.exists():
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
Returns an empty dict when the file is absent — the task description is
derived later from the video if needed. Reuses the library-level
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
frame indexed by task string with a ``task_index`` column.
"""
if not (root / DEFAULT_TASKS_PATH).exists():
return {}
table = pq.read_table(tasks_path)
cols = {name: table.column(name).to_pylist() for name in table.column_names}
if "task_index" in cols and "task" in cols:
return dict(zip(cols["task_index"], cols["task"], strict=True))
raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")
tasks = load_tasks(root)
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
@@ -36,9 +36,9 @@ from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"module_1",
"module_2",
"module_3",
"plan",
"interjections",
"vqa",
)
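To illustrate what the renamed staging keys mean for callers, here is a minimal in-memory sketch. It assumes that staging rejects names outside `_MODULES`; the class `InMemoryStaging` is hypothetical and does not reflect how `EpisodeStaging` stores rows on disk.

```python
from __future__ import annotations

from typing import Any

# Key names taken from the tuple above.
_MODULES: tuple[str, ...] = ("plan", "interjections", "vqa")


class InMemoryStaging:
    """Hypothetical stand-in for a per-episode staging area keyed by module name."""

    def __init__(self) -> None:
        self._rows: dict[str, list[dict[str, Any]]] = {}

    def _check(self, module: str) -> None:
        # Assumption: unknown module names are an error rather than a new key.
        if module not in _MODULES:
            raise ValueError(f"unknown module {module!r}; expected one of {_MODULES}")

    def write(self, module: str, rows: list[dict[str, Any]]) -> None:
        self._check(module)
        # Copy so later mutation of the caller's list does not leak in.
        self._rows[module] = list(rows)

    def read(self, module: str) -> list[dict[str, Any]]:
        self._check(module)
        return list(self._rows.get(module, []))
```

Under that assumption, code still using an old key such as `"module_1"` fails loudly instead of silently reading an empty list.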
@@ -15,7 +15,7 @@
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after Modules 1-3 have all written their per-episode artifacts but
Runs after all three modules have written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
@@ -218,11 +218,11 @@ class StagingValidator:
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
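The routing rule enforced above can be sketched in isolation. The style-to-column table below is copied from the docs table in this change; the constant values, the helper name `routing_error`, and the dictionaries are assumptions for illustration, not the validator's actual code.

```python
from __future__ import annotations

# Assumed to match the column names used in the docs table.
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"

# Style -> column. ``None`` stands for the speech tool-call atom (style=null).
_STYLE_TO_COLUMN: dict[str | None, str] = {
    "subtask": LANGUAGE_PERSISTENT,
    "plan": LANGUAGE_PERSISTENT,
    "memory": LANGUAGE_PERSISTENT,
    "task_aug": LANGUAGE_PERSISTENT,
    "interjection": LANGUAGE_EVENTS,
    "vqa": LANGUAGE_EVENTS,
    None: LANGUAGE_EVENTS,
}

# Module -> the only column that module is allowed to write.
_MODULE_TO_COLUMN: dict[str, str] = {
    "plan": LANGUAGE_PERSISTENT,
    "interjections": LANGUAGE_EVENTS,
    "vqa": LANGUAGE_EVENTS,
}


def routing_error(module: str, style: str | None) -> str | None:
    """Return an error message if ``style`` routes to the wrong column, else None."""
    if style not in _STYLE_TO_COLUMN:
        return f"module={module}: unknown style {style!r}"
    target = _STYLE_TO_COLUMN[style]
    expected = _MODULE_TO_COLUMN.get(module)
    if expected is not None and target != expected:
        return f"module={module} emitted style {style!r} that routes to {target} (must be {expected})"
    return None
```

For example, `routing_error("plan", "vqa")` reports a violation because `vqa` rows belong in `language_events`, while `routing_error("vqa", "vqa")` returns `None`.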
@@ -32,10 +32,20 @@ The client speaks one method, :meth:`VlmClient.generate_json`, which:
from __future__ import annotations
import atexit
import base64
import io
import json
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
import urllib.request
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Protocol
@@ -212,10 +222,8 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
# convolution kernels — slower but functional.
import os as _os # noqa: PLC0415
if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
import torch as _torch # noqa: PLC0415
if os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
_torch.backends.cudnn.enabled = False
llm_kwargs: dict[str, Any] = {
@@ -259,9 +267,7 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
"for VL models."
)
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
import os as _os # noqa: PLC0415
use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
# post-load dispatch path (the alloc fails in accelerate's hook setup
# even with TBs of host RAM). Default to manual: load on CPU with
@@ -276,7 +282,7 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
trust_remote_code=config.trust_remote_code,
)
else:
import torch as _torch # noqa: PLC0415
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
model = auto_cls.from_pretrained(
config.model_id,
@@ -390,8 +396,6 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
if len(batch) <= 1 or config.client_concurrency <= 1:
return [_one_call(messages, max_tok, temp) for messages in batch]
# Parallel fan-out — vllm batches these on the server side.
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
max_workers = min(config.client_concurrency, len(batch))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
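The fan-out pattern in this hunk can be sketched as a standalone helper. The hunk does not show how the futures are collected, so reading them back in submission order is an assumption, and `fan_out` is a hypothetical name.

```python
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any


def fan_out(call: Callable[[Any], Any], batch: Sequence[Any], concurrency: int) -> list[Any]:
    """Hypothetical helper: run ``call`` over ``batch`` and keep input order."""
    if len(batch) <= 1 or concurrency <= 1:
        # Serial path, mirroring the early return in the hunk above.
        return [call(item) for item in batch]
    max_workers = min(concurrency, len(batch))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(call, item) for item in batch]
        # Reading results in submission order keeps outputs aligned with
        # inputs, whatever order the calls actually finish in.
        return [f.result() for f in futures]
```

Collecting by submission order rather than completion order matters when the caller pairs results with their inputs by position.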
@@ -411,15 +415,6 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
Returns the list of ``api_base`` URLs the client should round-robin
across.
"""
import atexit # noqa: PLC0415
import os as _os # noqa: PLC0415
import shlex # noqa: PLC0415
import signal # noqa: PLC0415
import subprocess # noqa: PLC0415
import sys # noqa: PLC0415
import threading # noqa: PLC0415
import time # noqa: PLC0415
n = config.parallel_servers
api_bases: list[str] = []
procs: list[subprocess.Popen] = []
@@ -449,7 +444,7 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
for i in range(n):
port = config.serve_port + i
gpu = i % num_gpus
env = _os.environ.copy()
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
api_base = f"http://localhost:{port}/v1"
@@ -522,8 +517,6 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
import urllib.request # noqa: PLC0415
url = api_base.rstrip("/") + "/models"
# ``api_base`` is the user-configured local-server URL we just spawned
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
@@ -546,14 +539,6 @@ def _spawn_inference_server(config: VlmConfig) -> str:
Returns the full ``api_base`` URL the OpenAI client should use.
"""
import atexit # noqa: PLC0415
import shlex # noqa: PLC0415
import signal # noqa: PLC0415
import subprocess # noqa: PLC0415
import sys # noqa: PLC0415
import threading # noqa: PLC0415
import time # noqa: PLC0415
cmd = config.serve_command
if not cmd:
cmd = (
@@ -695,8 +680,6 @@ def _to_openai_messages(
def _file_to_data_url(path: str) -> str:
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
import base64 # noqa: PLC0415
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
return f"data:video/mp4;base64,{b64}"
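Both data-URL helpers in this file follow the same encode-then-prefix shape. A generic sketch, with the hypothetical name `bytes_to_data_url` and a caller-supplied MIME type as assumptions, looks like this:

```python
import base64


def bytes_to_data_url(payload: bytes, mime_type: str) -> str:
    """Hypothetical helper: wrap raw bytes in a base64 ``data:`` URL."""
    # Base64 output is pure ASCII, so decoding with "ascii" cannot fail.
    b64 = base64.b64encode(payload).decode("ascii")
    return f"data:{mime_type};base64,{b64}"
```

Note that base64 inflates the payload by roughly a third, so whole video files sent this way make for large request bodies.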
@@ -704,9 +687,6 @@ def _file_to_data_url(path: str) -> str:
def _pil_to_data_url(image: Any) -> str:
"""Encode a PIL.Image as a base64 data URL."""
import base64 # noqa: PLC0415
import io # noqa: PLC0415
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
@@ -29,7 +29,7 @@ For every episode the writer:
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct (PR 1) for every speech atom. The tool *schema* (the description
struct for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant
``SAY_TOOL_SCHEMA`` below and downstream chat-template consumers import
it directly rather than reading a redundant per-row column.
@@ -69,7 +69,7 @@ from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
# Tool schema constants moved to lerobot.datasets.language in PR 1 — single
# Tool schema constants live in lerobot.datasets.language — single
# source of truth. Re-exported here so existing imports
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
# keep working.
@@ -309,8 +309,8 @@ class LanguageColumnsWriter:
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — see PR 1's
# `tests/datasets/test_language.py` which exercises the same flow.
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
# exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
+18 -16
@@ -24,7 +24,7 @@ Example:
--root=/path/to/dataset \\
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
For distributed runs, see ``examples/annotation/run_hf_job.py``.
For distributed runs, see ``examples/annotations/run_hf_job.py``.
"""
import logging
@@ -65,27 +65,27 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
vlm = make_vlm_client(cfg.vlm)
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
# Surface the resolved cameras up front so silent Module-3-no-op
# regressions are obvious in job output rather than discovered post-hoc
# by counting parquet rows.
# Surface the resolved cameras up front so a silent vqa-module no-op
# is obvious in job output rather than discovered post-hoc by counting
# parquet rows.
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
logger.info(
"annotate: frame_provider default camera=%r, all cameras=%s",
getattr(frame_provider, "camera_key", None),
cam_keys,
)
if cfg.module_3.enabled and not cam_keys:
if cfg.vqa.enabled and not cam_keys:
logger.warning(
"annotate: Module 3 (VQA) is enabled but no cameras were "
"resolved — Module 3 will produce zero VQA rows. Check "
"annotate: the vqa module is enabled but no cameras were "
"resolved — it will produce zero VQA rows. Check "
"meta/info.json for observation.images.* features, or pass "
"--vlm.camera_key=<key> to seed the cameras list."
)
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
module_2 = InterjectionsAndSpeechModule(
vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
interjections = InterjectionsAndSpeechModule(
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
)
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
writer = LanguageColumnsWriter()
validator = StagingValidator(
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
@@ -93,9 +93,9 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
executor = Executor(
config=cfg,
module_1=module_1,
module_2=module_2,
module_3=module_3,
plan=plan,
interjections=interjections,
vqa=vqa,
writer=writer,
validator=validator,
)
@@ -113,14 +113,16 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
logger.warning(w)
if cfg.push_to_hub:
if cfg.repo_id is None:
raise ValueError("--push_to_hub requires --repo_id (the dataset repo to push to).")
_push_to_hub(root, cfg)
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
"""Upload the annotated dataset directory to the Hugging Face Hub."""
"""Upload the annotated dataset directory back to ``cfg.repo_id`` on the Hub."""
from huggingface_hub import HfApi # noqa: PLC0415
repo_id = cfg.push_to_hub
repo_id = cfg.repo_id
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
api = HfApi()
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
+6 -67
@@ -15,85 +15,24 @@
# limitations under the License.
"""Shared fixtures for annotation-pipeline tests.
Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests
can exercise real parquet reads and writes without needing a checked-in
LFS dataset.
The on-disk dataset builder lives with the other dataset factories in
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
these fixtures only wire it into pytest.
"""
from __future__ import annotations
import json
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
def _make_episode_table(
episode_index: int,
num_frames: int,
*,
fps: int = 10,
task_index: int = 0,
) -> pa.Table:
timestamps = [round(i / fps, 6) for i in range(num_frames)]
frame_indices = list(range(num_frames))
return pa.Table.from_pydict(
{
"episode_index": [episode_index] * num_frames,
"frame_index": frame_indices,
"timestamp": timestamps,
"task_index": [task_index] * num_frames,
"subtask_index": [0] * num_frames, # legacy column the writer must drop
}
)
def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path:
"""Create a fixture dataset under ``root``.
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet``
so the writer's per-shard rewrite path is exercised.
"""
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
tasks = {}
for episode_index, num_frames, task_text in episode_specs:
task_index = len(tasks)
if task_text not in tasks.values():
tasks[task_index] = task_text
else:
task_index = next(k for k, v in tasks.items() if v == task_text)
table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index)
path = data_dir / f"file-{episode_index:03d}.parquet"
pq.write_table(table, path)
meta_dir = root / "meta"
meta_dir.mkdir(parents=True, exist_ok=True)
tasks_table = pa.Table.from_pydict(
{
"task_index": list(tasks.keys()),
"task": list(tasks.values()),
}
)
pq.write_table(tasks_table, meta_dir / "tasks.parquet")
info = {
"codebase_version": "v3.1",
"fps": fps,
"total_episodes": len(episode_specs),
}
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
return root
from tests.fixtures.dataset_factories import build_annotation_dataset
@pytest.fixture
def fixture_dataset_root(tmp_path: Path) -> Path:
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
return _build_dataset(
return build_annotation_dataset(
tmp_path / "ds",
episode_specs=[
(0, 12, "Could you tidy the kitchen please?"),
@@ -105,7 +44,7 @@ def fixture_dataset_root(tmp_path: Path) -> Path:
@pytest.fixture
def single_episode_root(tmp_path: Path) -> Path:
return _build_dataset(
return build_annotation_dataset(
tmp_path / "ds_one",
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
fps=10,
+14 -37
@@ -15,22 +15,19 @@
# limitations under the License.
"""Opt-in E2E smoke run for ``make annotation-e2e``.
Builds the same fixture used by the pytest suite, runs the full
annotation pipeline against it with a stub VLM, and prints a short report.
This is intentionally not a pytest test; it exercises the CLI plumbing
without depending on conftest.py fixtures.
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
runs the full annotation pipeline against it with a stub VLM, and prints a
short report. This is intentionally not a pytest test (it exercises the
CLI plumbing), but it reuses the same on-disk dataset builder as the pytest
fixtures so there is no duplicated fixture code.
"""
from __future__ import annotations
import json
import sys
import tempfile
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
@@ -41,31 +38,7 @@ from lerobot.annotations.steerable_pipeline.modules import (
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
def _build_dataset(root: Path) -> Path:
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
n = 30
timestamps = [round(i / 10, 6) for i in range(n)]
table = pa.Table.from_pydict(
{
"episode_index": [0] * n,
"frame_index": list(range(n)),
"timestamp": timestamps,
"task_index": [0] * n,
"subtask_index": [0] * n,
}
)
pq.write_table(table, data_dir / "file-000.parquet")
meta = root / "meta"
meta.mkdir(parents=True, exist_ok=True)
pq.write_table(
pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]}),
meta / "tasks.parquet",
)
(meta / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10}))
return root
from tests.fixtures.dataset_factories import build_annotation_dataset
def _stub_responder(messages):
@@ -102,14 +75,18 @@ def _stub_responder(messages):
def main() -> int:
with tempfile.TemporaryDirectory() as tmp:
root = _build_dataset(Path(tmp) / "ds")
root = build_annotation_dataset(
Path(tmp) / "ds",
episode_specs=[(0, 30, "Pour water into the cup.")],
fps=10,
)
vlm = StubVlmClient(responder=_stub_responder)
cfg = AnnotationPipelineConfig()
executor = Executor(
config=cfg,
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1),
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed),
module_3=GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed),
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
+16 -16
@@ -23,9 +23,9 @@ from pathlib import Path
from typing import Any
from lerobot.annotations.steerable_pipeline.config import (
Module1Config,
Module2Config,
Module3Config,
InterjectionsConfig,
PlanConfig,
VqaConfig,
)
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
@@ -84,11 +84,11 @@ def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path:
"Update the memory": {"memory": "wiped the counter once"},
},
)
module = PlanSubtasksMemoryModule(vlm=vlm, config=Module1Config())
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig())
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_1")
rows = staging.read("plan")
styles = {r["style"] for r in rows}
assert {"subtask", "plan", "memory"}.issubset(styles)
@@ -108,12 +108,12 @@ def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: P
)
module = InterjectionsAndSpeechModule(
vlm=vlm,
config=Module2Config(max_interjections_per_episode=0),
config=InterjectionsConfig(max_interjections_per_episode=0),
)
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_2")
rows = staging.read("interjections")
assert len(rows) == 1
only = rows[0]
assert only["role"] == "assistant"
@@ -151,7 +151,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
)
module = InterjectionsAndSpeechModule(
vlm=vlm,
config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2),
config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
seed=7,
)
record = next(iter_episodes(fixture_dataset_root))
@@ -161,7 +161,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
# production executor guarantees Module 1 ran first).
boundary_ts = float(record.frame_timestamps[len(record.frame_timestamps) // 2])
staging.write(
"module_1",
"plan",
[
{
"role": "assistant",
@@ -180,7 +180,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
],
)
module.run_episode(record, staging)
rows = staging.read("module_2")
rows = staging.read("interjections")
interjections = [r for r in rows if r["style"] == "interjection"]
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
@@ -198,14 +198,14 @@ def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_
vlm = make_canned_responder({"frame-grounded visual question": payload})
module = GeneralVqaModule(
vlm=vlm,
config=Module3Config(vqa_emission_hz=1.0, K=3),
config=VqaConfig(vqa_emission_hz=1.0, K=3),
seed=1,
frame_provider=_StubFrameProvider(cameras=("observation.images.top", "observation.images.wrist")),
)
record = next(iter_episodes(single_episode_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_3")
rows = staging.read("vqa")
# every vqa row must carry a camera tag and one of the configured cameras
for r in rows:
assert r["style"] == "vqa"
@@ -257,7 +257,7 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
# call is the subtask one — keeps the assertions below focused on
# ``_generate_subtasks`` rather than fighting the order of unrelated
# text-only Module-1 sub-prompts.
config=Module1Config(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
config=PlanConfig(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
frame_provider=provider,
)
record = next(iter_episodes(fixture_dataset_root))
@@ -304,7 +304,7 @@ def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path,
provider = _StubFrameProvider()
module = GeneralVqaModule(
vlm=_spy_responder(captured, payload),
config=Module3Config(vqa_emission_hz=1.0, K=1),
config=VqaConfig(vqa_emission_hz=1.0, K=1),
seed=0,
frame_provider=provider,
)
@@ -336,14 +336,14 @@ def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_
vlm = make_canned_responder({"frame-grounded visual question": payload})
module = GeneralVqaModule(
vlm=vlm,
config=Module3Config(vqa_emission_hz=1.0, K=2),
config=VqaConfig(vqa_emission_hz=1.0, K=2),
seed=2,
frame_provider=_StubFrameProvider(),
)
record = next(iter_episodes(single_episode_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("module_3")
rows = staging.read("vqa")
for row in rows:
if row["role"] == "assistant" and row["style"] == "vqa":
decoded = json.loads(row["content"])
@@ -23,9 +23,9 @@ import pyarrow.parquet as pq
from lerobot.annotations.steerable_pipeline.config import (
AnnotationPipelineConfig,
Module1Config,
Module2Config,
Module3Config,
InterjectionsConfig,
PlanConfig,
VqaConfig,
)
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.modules import (
@@ -115,15 +115,15 @@ def _build_executor() -> Executor:
},
)
config = AnnotationPipelineConfig(
module_1=Module1Config(),
module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5),
module_3=Module3Config(vqa_emission_hz=1.0, K=2),
plan=PlanConfig(),
interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
)
return Executor(
config=config,
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1),
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed),
module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed),
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
writer=LanguageColumnsWriter(),
validator=StagingValidator(),
)
+6 -6
@@ -34,7 +34,7 @@ def _validate(root: Path, staging_dir: Path):
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_3",
"vqa",
[
{
"role": "assistant",
@@ -53,7 +53,7 @@ def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_2",
"interjections",
[
speech_atom(0.0, "Got it."),
# interjection at 0.3s with NO paired speech
@@ -74,7 +74,7 @@ def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: P
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_1",
"plan",
[
{
"role": "assistant",
@@ -93,7 +93,7 @@ def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path,
],
)
EpisodeStaging(staging_dir, 0).write(
"module_2",
"interjections",
[
speech_atom(0.0, "Got it."),
speech_atom(0.4, "Replanning."),
@@ -115,11 +115,11 @@ def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path,
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
staging_dir = tmp_path / "stage"
EpisodeStaging(staging_dir, 0).write(
"module_1",
"plan",
[
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
],
)
report = _validate(fixture_dataset_root, staging_dir)
assert not report.ok
assert any("module_1 emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
assert any("plan emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
+17 -17
@@ -35,17 +35,17 @@ def _stage_episode(
staging_dir: Path,
episode_index: int,
*,
module_1: list[dict] | None = None,
module_2: list[dict] | None = None,
module_3: list[dict] | None = None,
plan: list[dict] | None = None,
interjections: list[dict] | None = None,
vqa: list[dict] | None = None,
) -> None:
staging = EpisodeStaging(staging_dir, episode_index)
if module_1 is not None:
staging.write("module_1", module_1)
if module_2 is not None:
staging.write("module_2", module_2)
if module_3 is not None:
staging.write("module_3", module_3)
if plan is not None:
staging.write("plan", plan)
if interjections is not None:
staging.write("interjections", interjections)
if vqa is not None:
staging.write("vqa", vqa)
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
@@ -54,7 +54,7 @@ def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path)
_stage_episode(
staging_dir,
0,
module_1=[
plan=[
{
"role": "assistant",
"content": "grasp the sponge",
@@ -94,7 +94,7 @@ def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Pat
_stage_episode(
staging_dir,
0,
module_2=[
interjections=[
speech_atom(0.0, "Got it."),
{
"role": "user",
@@ -127,7 +127,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
_stage_episode(
staging_dir,
0,
module_1=[
plan=[
{
"role": "assistant",
"content": "do X",
@@ -150,7 +150,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
"tool_calls": None,
},
],
module_2=[
interjections=[
speech_atom(0.0, "OK"),
{
"role": "user",
@@ -161,7 +161,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
},
speech_atom(0.2, "Waiting"),
],
module_3=[
vqa=[
{
"role": "user",
"content": "where is the cup?",
@@ -201,7 +201,7 @@ def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_p
_stage_episode(
staging_dir,
0,
module_1=[
plan=[
{
"role": "assistant",
"content": "do X",
@@ -277,7 +277,7 @@ def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path:
_stage_episode(
staging_dir,
0,
module_1=[
plan=[
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
],
)
@@ -316,7 +316,7 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
_stage_episode(
staging_dir,
0,
module_1=[
plan=[
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
],
)
+61
@@ -555,3 +555,64 @@ def lerobot_dataset_factory(
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
    return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)


def build_annotation_dataset(
    root: Path,
    episode_specs: list[tuple[int, int, str]],
    *,
    fps: int = 10,
) -> Path:
    """Build a minimal LeRobot-shaped dataset on disk for annotation tests.

    ``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
    Each episode is written to its own
    ``data/chunk-000/file-{ep:03d}.parquet`` so the writer's per-shard
    rewrite path is exercised. The dataset carries the minimum
    ``meta/tasks.parquet`` + ``meta/info.json`` the reader / executor need;
    it has no videos, so the modules fall back to text-only prompts.

    Shared by the annotation-pipeline pytest fixtures (``tests/annotations/
    conftest.py``) and the opt-in E2E smoke run so the fixture shape lives
    in exactly one place.
    """
    from lerobot.datasets.io_utils import write_tasks
    from lerobot.utils.io_utils import write_json

    data_dir = root / "data" / "chunk-000"
    data_dir.mkdir(parents=True, exist_ok=True)

    tasks: dict[int, str] = {}
    for episode_index, num_frames, task_text in episode_specs:
        if task_text not in tasks.values():
            tasks[len(tasks)] = task_text
        task_index = next(k for k, v in tasks.items() if v == task_text)
        frame = pd.DataFrame(
            {
                "episode_index": [episode_index] * num_frames,
                "frame_index": list(range(num_frames)),
                "timestamp": [round(i / fps, 6) for i in range(num_frames)],
                "task_index": [task_index] * num_frames,
                "subtask_index": [0] * num_frames,  # legacy column the writer must drop
            }
        )
        frame.to_parquet(data_dir / f"file-{episode_index:03d}.parquet", index=False)

    # Canonical tasks frame: indexed by task string with a ``task_index``
    # column, matching what ``lerobot.datasets.io_utils.load_tasks`` expects.
    tasks_df = pd.DataFrame(
        {"task_index": list(tasks.keys())},
        index=pd.Index(list(tasks.values()), name="task"),
    )
    write_tasks(tasks_df, root)
    write_json(
        {
            "codebase_version": "v3.1",
            "fps": fps,
            "features": {},
            "total_episodes": len(episode_specs),
        },
        root / "meta" / "info.json",
    )
    return root
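
The task-index bookkeeping inside `build_annotation_dataset` can be looked at in isolation. The following is a minimal standalone sketch, not part of this commit: `assign_task_indices` and `frame_timestamps` are hypothetical helper names, and the sketch keeps a task -> index mapping (the inverse of the index -> task dict used in the fixture) so that repeated tasks resolve with a single dict lookup. It assumes only what the fixture shows: indices are handed out in first-seen order, and timestamps are `i / fps` rounded to six decimals.

```python
def assign_task_indices(task_texts: list[str]) -> tuple[dict[str, int], list[int]]:
    """Give each distinct task string an index in first-seen order.

    Hypothetical helper for illustration; returns the task -> index mapping
    and the per-episode list of task indices.
    """
    index_by_task: dict[str, int] = {}
    per_episode: list[int] = []
    for text in task_texts:
        # setdefault only stores a new index the first time a task is seen;
        # len() is evaluated before the insert, so indices run 0, 1, 2, ...
        per_episode.append(index_by_task.setdefault(text, len(index_by_task)))
    return index_by_task, per_episode


def frame_timestamps(num_frames: int, fps: int) -> list[float]:
    """Per-frame timestamps in seconds, rounded as in the fixture."""
    return [round(i / fps, 6) for i in range(num_frames)]


index_by_task, per_episode = assign_task_indices(
    ["wipe the table", "stack the cups", "wipe the table"]
)
timestamps = frame_timestamps(num_frames=3, fps=10)
```

Two episodes that share a task string end up with the same `task_index`, which is the property the fixture relies on when it writes one tasks frame for the whole dataset.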