mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
review: address CarolinePascal feedback
- name the three modules everywhere (plan / interjections / vqa) instead of module_1/2/3 — config classes, config fields, executor params, staging keys and phase names now carry the module name - rename examples/annotation -> examples/annotations; add the Apache header to run_hf_job.py - drop the unused GeneralVqaModule._generate_one - remove "PR 1" references from comments/docstrings - frames.py: rely on the always-defined LeRobotDatasetMetadata.camera_keys - executor.py: read/write meta/info.json via load_info / write_info - reader.py: load meta/tasks.parquet via io_utils.load_tasks - make --push_to_hub a bool; push the annotated dataset back to --repo_id - move the on-disk test dataset builder into tests/fixtures (build_annotation_dataset); run_e2e_smoke reuses it - clarify in the docs that the vqa module grounds each pair on a single frame (K = per-tick anchor count) - hoist stdlib dynamic imports to module scope Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -11,14 +11,14 @@ Three modules write into a per-episode staging tree, then a single writer
|
||||
rewrites the data shards in place:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | -------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | Module 1 |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | Module 1 |
|
||||
| `interjection` | `language_events` | Module 2 |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
|
||||
| `vqa` (user / assistant pair) | `language_events` | Module 3 |
|
||||
| ------------------------------------------- | --------------------- | -------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections`|
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections`|
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet — the tool
|
||||
catalog lives at `meta/info.json["tools"]` instead (see
|
||||
@@ -45,20 +45,24 @@ uv run lerobot-annotate \
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
The pipeline attaches actual camera footage to every Module 1/2/3 prompt
|
||||
by default, decoded from the dataset's first `observation.images.*`
|
||||
stream. Override with `--vlm.camera_key=observation.images.<name>` to
|
||||
pin a specific viewpoint. Datasets with no video tracks fall back to
|
||||
text-only prompts automatically.
|
||||
The pipeline attaches actual camera footage to every `plan` /
|
||||
`interjections` / `vqa` prompt by default, decoded from the dataset's
|
||||
first `observation.images.*` stream. Override with
|
||||
`--vlm.camera_key=observation.images.<name>` to pin a specific
|
||||
viewpoint. Datasets with no video tracks fall back to text-only prompts
|
||||
automatically.
|
||||
|
||||
**Module 1 sees the whole episode as one video block.** Subtask
|
||||
**The `plan` module sees the whole episode as one video block.** Subtask
|
||||
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||
and decides where to cut. There is no keyframe stride or count knob —
|
||||
`--module_1.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. Module 2 attaches a
|
||||
short window of frames around the interjection timestamp; Module 3
|
||||
attaches the exact emission frame to each VQA pair.
|
||||
`--plan.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. The `interjections`
|
||||
module attaches a short window of frames straddling the interjection
|
||||
timestamp. The `vqa` module grounds each VQA pair on a single frame —
|
||||
its `--vqa.K` knob sets how many consecutive frames each emission tick
|
||||
anchors, and every anchored frame gets its own VQA pair on that one
|
||||
frame (there is no per-pair frame window).
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
@@ -67,15 +71,16 @@ Distributed annotation is delegated to
|
||||
ships a launcher script you copy and edit for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`examples/annotation/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotation/run_hf_job.py)
|
||||
[`examples/annotations/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
spawns one `h200x2` job that:
|
||||
|
||||
1. installs the branch under test plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||
3. runs Modules 1 / 2 / 3 across the dataset via `lerobot-annotate`,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
via `lerobot-annotate`,
|
||||
4. uploads the annotated dataset to `--push_to_hub`.
|
||||
|
||||
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||
@@ -115,7 +120,8 @@ Each module writes its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||
prompt iteration cheap — re-running one module overwrites only its own
|
||||
JSONL file before the writer composes the final parquet. Modules can be
|
||||
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
|
||||
disabled via `--plan.enabled=false` (and likewise `--interjections.enabled`
|
||||
/ `--vqa.enabled`) to
|
||||
test them in isolation.
|
||||
|
||||
## Validation/report checks before final write
|
||||
@@ -134,18 +140,18 @@ Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||
|
||||
## Paper inspirations per module
|
||||
|
||||
- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
- **`plan` module — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||
what" detail.
|
||||
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
- **`plan` module — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
compression directive: keep only minimal relevant information; functional
|
||||
outcomes preserved, specific attributes dropped.
|
||||
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
|
||||
- **`interjections` module.** Hi Robot scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||
arguments:{text:...}}}]`).
|
||||
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
- **`vqa` module.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||
@@ -157,9 +163,9 @@ references rather than rewriting from scratch.
|
||||
|
||||
## Compute and list-size estimates
|
||||
|
||||
Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
|
||||
O(`max_interjections_per_episode`) Module 2 calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
|
||||
Per episode, the pipeline issues O(`max_steps`) `plan`-module calls,
|
||||
O(`max_interjections_per_episode`) `interjections`-module calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) `vqa`-module calls. With defaults
|
||||
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||
KB at most (parquet dictionary-encodes one entry per episode);
|
||||
|
||||
@@ -1,18 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x2`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs Module 1/2/3 across the dataset (per-camera VQA via PR 3471),
|
||||
4. uploads the annotated dataset to ``--push_to_hub``.
|
||||
3. runs the plan / interjections / vqa modules across the dataset,
|
||||
4. uploads the annotated dataset back to ``--repo_id``.
|
||||
|
||||
``--repo_id`` is both the download source and, with ``--push_to_hub=true``,
|
||||
the upload destination — the job annotates the dataset in place.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` below to point at your own dataset / target hub repo.
|
||||
Adjust ``CMD`` below to point at your own dataset.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -36,7 +53,9 @@ CMD = (
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=imstevenpmwork/super_poulain_draft "
|
||||
# The dataset to annotate; also the push destination (annotate in place).
|
||||
"--repo_id=<your-org>/<your-dataset> "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=2 "
|
||||
@@ -50,11 +69,10 @@ CMD = (
|
||||
"--executor.episode_parallelism=32 "
|
||||
"--vlm.chat_template_kwargs='{enable_thinking: false}' "
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
"--module_1.frames_per_second=1.0 "
|
||||
"--module_1.use_video_url=true "
|
||||
"--module_1.use_video_url_fps=1.0 "
|
||||
"--module_3.K=1 --module_3.vqa_emission_hz=0.2 "
|
||||
"--push_to_hub=pepijn223/super_poulain_qwen36moe-3"
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.use_video_url=true "
|
||||
"--plan.use_video_url_fps=1.0 "
|
||||
"--vqa.K=1 --vqa.vqa_emission_hz=0.2"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
+1
-1
@@ -205,7 +205,7 @@ peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||
# delegated to Hugging Face Jobs (see examples/annotation/run_hf_job.py).
|
||||
# delegated to Hugging Face Jobs (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
|
||||
@@ -19,9 +19,9 @@
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
|
||||
@@ -22,12 +22,12 @@ from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module1Config:
|
||||
"""Module 1: plan + subtasks + memory + task augmentation.
|
||||
class PlanConfig:
|
||||
"""``plan`` module: plan + subtasks + memory + task augmentation.
|
||||
|
||||
Module 1 attaches the whole episode as one Qwen-VL video block;
|
||||
``max_video_frames`` only caps the frames packed in (a model-capacity
|
||||
bound, not an annotation-logic knob).
|
||||
The ``plan`` module attaches the whole episode as one Qwen-VL video
|
||||
block; ``max_video_frames`` only caps the frames packed in (a
|
||||
model-capacity bound, not an annotation-logic knob).
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
@@ -39,8 +39,8 @@ class Module1Config:
|
||||
# When to derive the task from the video instead of using
|
||||
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
|
||||
# missing canonical task), or ``always``. The derived task replaces the
|
||||
# canonical one for every Module-1 prompt; ``meta/tasks.parquet`` is
|
||||
# never modified.
|
||||
# canonical one for every ``plan``-module prompt; ``meta/tasks.parquet``
|
||||
# is never modified.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
@@ -51,21 +51,22 @@ class Module1Config:
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), Module 1 sends a
|
||||
# ``video_url`` block pointing at a per-episode mp4 subclip and lets the
|
||||
# server sample frames at ``use_video_url_fps``.
|
||||
# When True (and backend supports it, e.g. ``openai``), the ``plan``
|
||||
# module sends a ``video_url`` block pointing at a per-episode mp4
|
||||
# subclip and lets the server sample frames at ``use_video_url_fps``.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module2Config:
|
||||
"""Module 2: interjections + paired speech."""
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each interjection emits a paired ``(interjection, speech)`` event row
|
||||
# and triggers a ``plan`` refresh at the same timestamp via Module 1.
|
||||
# and triggers a ``plan`` refresh at the same timestamp via the
|
||||
# ``plan`` module.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
@@ -77,8 +78,8 @@ class Module2Config:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module3Config:
|
||||
"""Module 3: general VQA."""
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
@@ -161,6 +162,8 @@ class AnnotationPipelineConfig:
|
||||
revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
# Hub dataset id. Used as the download source when ``root`` is unset,
|
||||
# and as the destination repo when ``push_to_hub`` is enabled.
|
||||
repo_id: str | None = None
|
||||
root: Path | None = None
|
||||
|
||||
@@ -169,9 +172,9 @@ class AnnotationPipelineConfig:
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
module_1: Module1Config = field(default_factory=Module1Config)
|
||||
module_2: Module2Config = field(default_factory=Module2Config)
|
||||
module_3: Module3Config = field(default_factory=Module3Config)
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
@@ -179,8 +182,9 @@ class AnnotationPipelineConfig:
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Upload the annotated dataset to the Hugging Face Hub when set.
|
||||
push_to_hub: str | None = None
|
||||
# When True, upload the annotated dataset back to ``repo_id`` on the
|
||||
# Hugging Face Hub. ``repo_id`` must be set for this to take effect.
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
|
||||
@@ -17,19 +17,20 @@
|
||||
|
||||
The executor plans **six phases** in the dependency order from the plan:
|
||||
|
||||
phase 1: Module 1 (plan + subtasks + memory)
|
||||
phase 2: Module 2 (interjections + speech)
|
||||
phase 3: Module 1 plan-update pass — re-runs plan emission at every
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: Module 3 (VQA)
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why Module 1 must be re-entered after Module 2 — to refresh
|
||||
``plan`` rows at interjection timestamps.
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); the runner inside the job
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
@@ -38,6 +39,8 @@ Episode-level concurrency is controlled by
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -71,7 +74,7 @@ class PipelineRunSummary:
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all four phases over a dataset root in-process.
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
@@ -80,9 +83,9 @@ class Executor:
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
module_1: Any # PlanSubtasksMemoryModule
|
||||
module_2: Any # InterjectionsAndSpeechModule
|
||||
module_3: Any # GeneralVqaModule
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
@@ -99,16 +102,16 @@ class Executor:
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: Module 1 (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
|
||||
# Phase 2: Module 2 (interjections + speech). Module 2 reads
|
||||
# Module 1's subtask rows from the same staging tree to ground
|
||||
# the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
|
||||
# Phase 3: Module 1 plan-update pass at interjection timestamps.
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: Module 3 (VQA)
|
||||
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
@@ -135,50 +138,37 @@ class Executor:
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = json.loads(info_path.read_text())
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
features = info.get("features")
|
||||
if not isinstance(features, dict):
|
||||
features = {}
|
||||
merged_features = {**features, **language_feature_info()}
|
||||
if merged_features != features:
|
||||
info["features"] = merged_features
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.get("tools")
|
||||
if not isinstance(existing, list):
|
||||
existing = []
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
merged = list(existing)
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
merged.append(SAY_TOOL_SCHEMA)
|
||||
if merged != existing:
|
||||
info["tools"] = merged
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
# Atomic replace — info.json is load-bearing for dataset
|
||||
# metadata, so a crash mid-write would brick the dataset.
|
||||
tmp_info = info_path.with_suffix(info_path.suffix + ".tmp")
|
||||
tmp_info.write_text(json.dumps(info, indent=2))
|
||||
tmp_info.replace(info_path)
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in merged]}",
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
@@ -189,9 +179,6 @@ class Executor:
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
import time as _time # noqa: PLC0415
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
|
||||
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
@@ -201,14 +188,14 @@ class Executor:
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = _time.time()
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = _time.time()
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, _time.time() - ep_start
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
@@ -230,38 +217,39 @@ class Executor:
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = _time.time() - t0
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
||||
therefore calls back into Module 1 with the interjection timestamps so
|
||||
Module 1's existing prompt path is reused.
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.module_1.enabled or not self.module_2.enabled:
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("module_2") if row.get("style") == "interjection"
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.module_1.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update",
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ class FrameProvider(Protocol):
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(Module 1, Module 2) keep working unchanged.
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
@@ -100,10 +100,11 @@ def null_provider() -> FrameProvider:
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for Module 1 (subtask
|
||||
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
@@ -112,7 +113,7 @@ class VideoFrameProvider:
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
@@ -122,7 +123,7 @@ class VideoFrameProvider:
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs Module 1/2/3 phases under a ThreadPoolExecutor (see
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
@@ -131,11 +132,10 @@ class VideoFrameProvider:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras
|
||||
# (``video_keys`` is video-only). Some datasets declare cameras with
|
||||
# ``dtype=image``, which would otherwise look empty here and silently
|
||||
# disable Module 3 even though the videos are there.
|
||||
keys = list(getattr(self._meta, "camera_keys", None) or self._meta.video_keys or [])
|
||||
# ``camera_keys`` covers both image- and video-stored cameras and is
|
||||
# always defined on the metadata (``[]`` in the worst case), so it is
|
||||
# the single source we need here.
|
||||
keys = list(self._meta.camera_keys)
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
@@ -275,10 +275,10 @@ class VideoFrameProvider:
|
||||
try:
|
||||
return _decode_pyav_direct(video_path, shifted, self.tolerance_s)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time decoding fails so silent
|
||||
# Module-3-no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
# Log loudly the first time decoding fails so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at
|
||||
# returned []) is debuggable from the job log instead of
|
||||
# post-hoc parquet inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = getattr(self, "_warned_decode_fail", False)
|
||||
if not already_warned:
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 3: general VQA at a timed cadence.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
||||
emission. For datasets with multiple cameras, every emission tick produces
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
@@ -26,7 +28,7 @@ per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
@@ -35,12 +37,13 @@ retry inside :meth:`VlmClient.generate_json`.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module3Config
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
@@ -89,7 +92,7 @@ class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module3Config
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@@ -99,7 +102,7 @@ class GeneralVqaModule:
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("module_3", [])
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
@@ -111,17 +114,15 @@ class GeneralVqaModule:
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"Module 3 (VQA) found no cameras on the frame provider — "
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("module_3", [])
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
@@ -140,13 +141,13 @@ class GeneralVqaModule:
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("module_3", [])
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
@@ -171,10 +172,10 @@ class GeneralVqaModule:
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_3", rows)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras Module 3 should iterate per emission tick.
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
@@ -214,17 +215,6 @@ class GeneralVqaModule:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
def _generate_one(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
return self._postprocess(result)
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
@@ -26,8 +26,8 @@ Two sub-passes:
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
|
||||
timestamps to refresh the ``plan`` row at the same instant.
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -37,7 +37,7 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module2Config
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
@@ -51,7 +51,7 @@ class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module2Config
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@@ -66,13 +66,13 @@ class InterjectionsAndSpeechModule:
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull Module 1's subtask spans for this episode so the
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. Module 1 ran first.
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("module_1"), episode_end_t=episode_end_t)
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("module_2", rows)
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -22,7 +22,7 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module1Config
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
@@ -46,13 +46,13 @@ class PlanSubtasksMemoryModule:
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
Module 2 completes).
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module1Config
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
@@ -61,14 +61,14 @@ class PlanSubtasksMemoryModule:
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other Module-1 prompt. May be
|
||||
# the canonical ``record.episode_task`` (default), or a fresh
|
||||
# Resolve the task that drives every other ``plan``-module prompt.
|
||||
# May be the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see Module1Config.derive_task_*).
|
||||
# empty / placeholder / forced-off (see PlanConfig.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# PR 1 renderer rotates ``${task}`` deterministically through them
|
||||
# so the policy sees diverse phrasings during training.
|
||||
# message renderer rotates ``${task}`` deterministically through
|
||||
# them so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
@@ -134,7 +134,7 @@ class PlanSubtasksMemoryModule:
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("module_1", rows)
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
@@ -156,7 +156,7 @@ class PlanSubtasksMemoryModule:
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives Module 1 for this episode.
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
@@ -182,7 +182,7 @@ class PlanSubtasksMemoryModule:
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers (factored out: every Module-1 prompt below follows
|
||||
# VLM call helpers (factored out: every ``plan``-module prompt below follows
|
||||
# the same "build messages → single VLM call → pull a named field"
|
||||
# shape, only differing in field name + post-processing).
|
||||
# ------------------------------------------------------------------
|
||||
@@ -258,7 +258,7 @@ class PlanSubtasksMemoryModule:
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("module_1")
|
||||
existing = staging.read("plan")
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
@@ -289,7 +289,7 @@ class PlanSubtasksMemoryModule:
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_1", new_rows)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
|
||||
@@ -38,6 +38,7 @@ from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@@ -83,8 +84,9 @@ def reconstruct_subtask_spans(
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by Module 1 (plan-update pass) and Module 2 (interjection
|
||||
anchoring), which both need the same span shape.
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
@@ -105,8 +107,9 @@ def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame (event rows must land on an
|
||||
exact frame, see PR 1's "exact event matching" rule).
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
@@ -115,14 +118,17 @@ def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
tasks_path = root / DEFAULT_TASKS_PATH
|
||||
if not tasks_path.exists():
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
table = pq.read_table(tasks_path)
|
||||
cols = {name: table.column(name).to_pylist() for name in table.column_names}
|
||||
if "task_index" in cols and "task" in cols:
|
||||
return dict(zip(cols["task_index"], cols["task"], strict=True))
|
||||
raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
|
||||
@@ -36,9 +36,9 @@ from typing import Any
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"module_1",
|
||||
"module_2",
|
||||
"module_3",
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after Modules 1–3 have all written their per-episode artifacts but
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
@@ -218,11 +218,11 @@ class StagingValidator:
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
@@ -32,10 +32,20 @@ The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
@@ -212,10 +222,8 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
if os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
@@ -259,9 +267,7 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
"for VL models."
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
use_accelerate = os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
@@ -276,7 +282,7 @@ def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
import torch as _torch # noqa: PLC0415 - optional GPU dep, deferred
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
@@ -390,8 +396,6 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
|
||||
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
@@ -411,15 +415,6 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import os as _os # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
@@ -449,7 +444,7 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = _os.environ.copy()
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
@@ -522,8 +517,6 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
@@ -546,14 +539,6 @@ def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
@@ -695,8 +680,6 @@ def _to_openai_messages(
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
@@ -704,9 +687,6 @@ def _file_to_data_url(path: str) -> str:
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
import io # noqa: PLC0415
|
||||
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
@@ -29,7 +29,7 @@ For every episode the writer:
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct (PR 1) for every speech atom. The tool *schema* (the description
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
@@ -69,7 +69,7 @@ from .staging import EpisodeStaging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants moved to lerobot.datasets.language in PR 1 — single
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
@@ -309,8 +309,8 @@ class LanguageColumnsWriter:
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — see PR 1's
|
||||
# `tests/datasets/test_language.py` which exercises the same flow.
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ Example:
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
For distributed runs, see ``examples/annotation/run_hf_job.py``.
|
||||
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -65,27 +65,27 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
|
||||
# Surface the resolved cameras up front so silent Module-3-no-op
|
||||
# regressions are obvious in job output rather than discovered post-hoc
|
||||
# by counting parquet rows.
|
||||
# Surface the resolved cameras up front so a silent vqa-module no-op
|
||||
# is obvious in job output rather than discovered post-hoc by counting
|
||||
# parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.module_3.enabled and not cam_keys:
|
||||
if cfg.vqa.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: Module 3 (VQA) is enabled but no cameras were "
|
||||
"resolved — Module 3 will produce zero VQA rows. Check "
|
||||
"annotate: the vqa module is enabled but no cameras were "
|
||||
"resolved — it will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
|
||||
module_2 = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
|
||||
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
|
||||
interjections = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
@@ -93,9 +93,9 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
module_1=module_1,
|
||||
module_2=module_2,
|
||||
module_3=module_3,
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
@@ -113,14 +113,16 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
if cfg.repo_id is None:
|
||||
raise ValueError("--push_to_hub requires --repo_id (the dataset repo to push to).")
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hugging Face Hub."""
|
||||
"""Upload the annotated dataset directory back to ``cfg.repo_id`` on the Hub."""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.push_to_hub
|
||||
repo_id = cfg.repo_id
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
|
||||
@@ -15,85 +15,24 @@
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests
|
||||
can exercise real parquet reads and writes without needing a checked-in
|
||||
LFS dataset.
|
||||
The on-disk dataset builder lives with the other dataset factories in
|
||||
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
|
||||
these fixtures only wire it into pytest.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_episode_table(
|
||||
episode_index: int,
|
||||
num_frames: int,
|
||||
*,
|
||||
fps: int = 10,
|
||||
task_index: int = 0,
|
||||
) -> pa.Table:
|
||||
timestamps = [round(i / fps, 6) for i in range(num_frames)]
|
||||
frame_indices = list(range(num_frames))
|
||||
return pa.Table.from_pydict(
|
||||
{
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": frame_indices,
|
||||
"timestamp": timestamps,
|
||||
"task_index": [task_index] * num_frames,
|
||||
"subtask_index": [0] * num_frames, # legacy column the writer must drop
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path:
|
||||
"""Create a fixture dataset under ``root``.
|
||||
|
||||
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
|
||||
Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet``
|
||||
so the writer's per-shard rewrite path is exercised.
|
||||
"""
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
tasks = {}
|
||||
for episode_index, num_frames, task_text in episode_specs:
|
||||
task_index = len(tasks)
|
||||
if task_text not in tasks.values():
|
||||
tasks[task_index] = task_text
|
||||
else:
|
||||
task_index = next(k for k, v in tasks.items() if v == task_text)
|
||||
table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index)
|
||||
path = data_dir / f"file-{episode_index:03d}.parquet"
|
||||
pq.write_table(table, path)
|
||||
|
||||
meta_dir = root / "meta"
|
||||
meta_dir.mkdir(parents=True, exist_ok=True)
|
||||
tasks_table = pa.Table.from_pydict(
|
||||
{
|
||||
"task_index": list(tasks.keys()),
|
||||
"task": list(tasks.values()),
|
||||
}
|
||||
)
|
||||
pq.write_table(tasks_table, meta_dir / "tasks.parquet")
|
||||
|
||||
info = {
|
||||
"codebase_version": "v3.1",
|
||||
"fps": fps,
|
||||
"total_episodes": len(episode_specs),
|
||||
}
|
||||
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
|
||||
|
||||
return root
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
|
||||
return _build_dataset(
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds",
|
||||
episode_specs=[
|
||||
(0, 12, "Could you tidy the kitchen please?"),
|
||||
@@ -105,7 +44,7 @@ def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
|
||||
@pytest.fixture
|
||||
def single_episode_root(tmp_path: Path) -> Path:
|
||||
return _build_dataset(
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds_one",
|
||||
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
|
||||
fps=10,
|
||||
|
||||
@@ -15,22 +15,19 @@
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the same fixture used by the pytest suite, runs the full
|
||||
annotation pipeline against it with a stub VLM, and prints a short report.
|
||||
This is intentionally not a pytest test — it exercises the CLI plumbing
|
||||
without depending on conftest.py fixtures.
|
||||
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
|
||||
runs the full annotation pipeline against it with a stub VLM, and prints a
|
||||
short report. This is intentionally not a pytest test — it exercises the
|
||||
CLI plumbing — but it reuses the same on-disk dataset builder as the pytest
|
||||
fixtures so there is no duplicated fixture code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
@@ -41,31 +38,7 @@ from lerobot.annotations.steerable_pipeline.modules import (
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
|
||||
|
||||
def _build_dataset(root: Path) -> Path:
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
n = 30
|
||||
timestamps = [round(i / 10, 6) for i in range(n)]
|
||||
table = pa.Table.from_pydict(
|
||||
{
|
||||
"episode_index": [0] * n,
|
||||
"frame_index": list(range(n)),
|
||||
"timestamp": timestamps,
|
||||
"task_index": [0] * n,
|
||||
"subtask_index": [0] * n,
|
||||
}
|
||||
)
|
||||
pq.write_table(table, data_dir / "file-000.parquet")
|
||||
meta = root / "meta"
|
||||
meta.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(
|
||||
pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]}),
|
||||
meta / "tasks.parquet",
|
||||
)
|
||||
(meta / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10}))
|
||||
return root
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
@@ -102,14 +75,18 @@ def _stub_responder(messages):
|
||||
|
||||
def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = _build_dataset(Path(tmp) / "ds")
|
||||
root = build_annotation_dataset(
|
||||
Path(tmp) / "ds",
|
||||
episode_specs=[(0, 30, "Pour water into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
vlm = StubVlmClient(responder=_stub_responder)
|
||||
cfg = AnnotationPipelineConfig()
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1),
|
||||
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed),
|
||||
module_3=GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed),
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
@@ -23,9 +23,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
@@ -84,11 +84,11 @@ def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path:
|
||||
"Update the memory": {"memory": "wiped the counter once"},
|
||||
},
|
||||
)
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=Module1Config())
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig())
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_1")
|
||||
rows = staging.read("plan")
|
||||
|
||||
styles = {r["style"] for r in rows}
|
||||
assert {"subtask", "plan", "memory"}.issubset(styles)
|
||||
@@ -108,12 +108,12 @@ def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: P
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=Module2Config(max_interjections_per_episode=0),
|
||||
config=InterjectionsConfig(max_interjections_per_episode=0),
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_2")
|
||||
rows = staging.read("interjections")
|
||||
assert len(rows) == 1
|
||||
only = rows[0]
|
||||
assert only["role"] == "assistant"
|
||||
@@ -151,7 +151,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2),
|
||||
config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
|
||||
seed=7,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
@@ -161,7 +161,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
# production executor guarantees Module 1 ran first).
|
||||
boundary_ts = float(record.frame_timestamps[len(record.frame_timestamps) // 2])
|
||||
staging.write(
|
||||
"module_1",
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -180,7 +180,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
],
|
||||
)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_2")
|
||||
rows = staging.read("interjections")
|
||||
|
||||
interjections = [r for r in rows if r["style"] == "interjection"]
|
||||
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
|
||||
@@ -198,14 +198,14 @@ def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=3),
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=3),
|
||||
seed=1,
|
||||
frame_provider=_StubFrameProvider(cameras=("observation.images.top", "observation.images.wrist")),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_3")
|
||||
rows = staging.read("vqa")
|
||||
# every vqa row must carry a camera tag and one of the configured cameras
|
||||
for r in rows:
|
||||
assert r["style"] == "vqa"
|
||||
@@ -257,7 +257,7 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
|
||||
# call is the subtask one — keeps the assertions below focused on
|
||||
# ``_generate_subtasks`` rather than fighting the order of unrelated
|
||||
# text-only Module-1 sub-prompts.
|
||||
config=Module1Config(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
|
||||
config=PlanConfig(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
@@ -304,7 +304,7 @@ def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path,
|
||||
provider = _StubFrameProvider()
|
||||
module = GeneralVqaModule(
|
||||
vlm=_spy_responder(captured, payload),
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=1),
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=1),
|
||||
seed=0,
|
||||
frame_provider=provider,
|
||||
)
|
||||
@@ -336,14 +336,14 @@ def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||
config=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
seed=2,
|
||||
frame_provider=_StubFrameProvider(),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_3")
|
||||
rows = staging.read("vqa")
|
||||
for row in rows:
|
||||
if row["role"] == "assistant" and row["style"] == "vqa":
|
||||
decoded = json.loads(row["content"])
|
||||
|
||||
@@ -23,9 +23,9 @@ import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
AnnotationPipelineConfig,
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
InterjectionsConfig,
|
||||
PlanConfig,
|
||||
VqaConfig,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
@@ -115,15 +115,15 @@ def _build_executor() -> Executor:
|
||||
},
|
||||
)
|
||||
config = AnnotationPipelineConfig(
|
||||
module_1=Module1Config(),
|
||||
module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5),
|
||||
module_3=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||
plan=PlanConfig(),
|
||||
interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
|
||||
vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
|
||||
)
|
||||
return Executor(
|
||||
config=config,
|
||||
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1),
|
||||
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed),
|
||||
module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed),
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ def _validate(root: Path, staging_dir: Path):
|
||||
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_3",
|
||||
"vqa",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -53,7 +53,7 @@ def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_2",
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
# interjection at 0.3s with NO paired speech
|
||||
@@ -74,7 +74,7 @@ def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: P
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_1",
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -93,7 +93,7 @@ def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path,
|
||||
],
|
||||
)
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_2",
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
speech_atom(0.4, "Replanning."),
|
||||
@@ -115,11 +115,11 @@ def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path,
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_1",
|
||||
"plan",
|
||||
[
|
||||
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("module_1 emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
assert any("plan emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
|
||||
@@ -35,17 +35,17 @@ def _stage_episode(
|
||||
staging_dir: Path,
|
||||
episode_index: int,
|
||||
*,
|
||||
module_1: list[dict] | None = None,
|
||||
module_2: list[dict] | None = None,
|
||||
module_3: list[dict] | None = None,
|
||||
plan: list[dict] | None = None,
|
||||
interjections: list[dict] | None = None,
|
||||
vqa: list[dict] | None = None,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, episode_index)
|
||||
if module_1 is not None:
|
||||
staging.write("module_1", module_1)
|
||||
if module_2 is not None:
|
||||
staging.write("module_2", module_2)
|
||||
if module_3 is not None:
|
||||
staging.write("module_3", module_3)
|
||||
if plan is not None:
|
||||
staging.write("plan", plan)
|
||||
if interjections is not None:
|
||||
staging.write("interjections", interjections)
|
||||
if vqa is not None:
|
||||
staging.write("vqa", vqa)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
@@ -54,7 +54,7 @@ def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path)
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
@@ -94,7 +94,7 @@ def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Pat
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_2=[
|
||||
interjections=[
|
||||
speech_atom(0.0, "Got it."),
|
||||
{
|
||||
"role": "user",
|
||||
@@ -127,7 +127,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
@@ -150,7 +150,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
module_2=[
|
||||
interjections=[
|
||||
speech_atom(0.0, "OK"),
|
||||
{
|
||||
"role": "user",
|
||||
@@ -161,7 +161,7 @@ def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> No
|
||||
},
|
||||
speech_atom(0.2, "Waiting"),
|
||||
],
|
||||
module_3=[
|
||||
vqa=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "where is the cup?",
|
||||
@@ -201,7 +201,7 @@ def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_p
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
@@ -277,7 +277,7 @@ def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path:
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
plan=[
|
||||
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
@@ -316,7 +316,7 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
plan=[
|
||||
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
|
||||
Vendored
+61
@@ -555,3 +555,64 @@ def lerobot_dataset_factory(
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||
|
||||
|
||||
def build_annotation_dataset(
|
||||
root: Path,
|
||||
episode_specs: list[tuple[int, int, str]],
|
||||
*,
|
||||
fps: int = 10,
|
||||
) -> Path:
|
||||
"""Build a minimal LeRobot-shaped dataset on disk for annotation tests.
|
||||
|
||||
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
|
||||
Each episode is written to its own
|
||||
``data/chunk-000/file-{ep:03d}.parquet`` so the writer's per-shard
|
||||
rewrite path is exercised. The dataset carries the minimum
|
||||
``meta/tasks.parquet`` + ``meta/info.json`` the reader / executor need;
|
||||
it has no videos, so the modules fall back to text-only prompts.
|
||||
|
||||
Shared by the annotation-pipeline pytest fixtures (``tests/annotations/
|
||||
conftest.py``) and the opt-in E2E smoke run so the fixture shape lives
|
||||
in exactly one place.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tasks: dict[int, str] = {}
|
||||
for episode_index, num_frames, task_text in episode_specs:
|
||||
if task_text not in tasks.values():
|
||||
tasks[len(tasks)] = task_text
|
||||
task_index = next(k for k, v in tasks.items() if v == task_text)
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": list(range(num_frames)),
|
||||
"timestamp": [round(i / fps, 6) for i in range(num_frames)],
|
||||
"task_index": [task_index] * num_frames,
|
||||
"subtask_index": [0] * num_frames, # legacy column the writer must drop
|
||||
}
|
||||
)
|
||||
frame.to_parquet(data_dir / f"file-{episode_index:03d}.parquet", index=False)
|
||||
|
||||
# Canonical tasks frame: indexed by task string with a ``task_index``
|
||||
# column, matching what ``lerobot.datasets.io_utils.load_tasks`` expects.
|
||||
tasks_df = pd.DataFrame(
|
||||
{"task_index": list(tasks.keys())},
|
||||
index=pd.Index(list(tasks.values()), name="task"),
|
||||
)
|
||||
write_tasks(tasks_df, root)
|
||||
|
||||
write_json(
|
||||
{
|
||||
"codebase_version": "v3.1",
|
||||
"fps": fps,
|
||||
"features": {},
|
||||
"total_episodes": len(episode_specs),
|
||||
},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
Reference in New Issue
Block a user