Mirror of https://github.com/huggingface/lerobot.git, synced 2026-06-16 15:57:03 +00:00

Compare commits: 2 commits

| Author | SHA1 | Date |
|---|---|---|
| | fa3eb9fce3 | |
| | 500c91ba92 | |
@@ -65,9 +65,6 @@ repos:
|
||||
name: Format Markdown with Prettier
|
||||
types_or: [markdown, mdx]
|
||||
args: [--prose-wrap=preserve]
|
||||
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
|
||||
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
|
||||
exclude: ^src/lerobot/templates/.*\.md$
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
|
||||
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,13 +101,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub.
|
||||
|
||||
@@ -135,7 +133,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -143,7 +140,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -45,8 +45,6 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` watches each episode's video with a vision-language
|
||||
model (VLM) and writes natural-language annotations back into your
|
||||
dataset. It fills the two language columns from the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — straight into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||
memory, interjections, speech, and visual Q&A that a policy can be
|
||||
trained on.
|
||||
|
||||
## How it fits together
|
||||
|
||||
```text
|
||||
your dataset lerobot-annotate
|
||||
(LeRobot v3.1)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ read episodes │
|
||||
└──────────────────────────┬──────────────────────────┘
|
||||
│
|
||||
┌────────────────────┼────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||
└────────────────────┼─────────────────────┘
|
||||
│ each module stages raw JSONL
|
||||
▼ into .annotate_staging/
|
||||
┌─────────────────┐
|
||||
│ validator │ ◀── checks everything
|
||||
└────────┬────────┘
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ writer │
|
||||
└────────┬────────┘
|
||||
▼
|
||||
data/chunk-*/file-*.parquet
|
||||
(+ meta/info.json tools)
|
||||
```
|
||||
|
||||
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||
single writer rewrites the dataset shards in place.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||
the two language columns:
|
||||
|
||||
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | --------------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
| `interjection` | `language_events` | `interjections` |
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
### How subtasks are generated
|
||||
|
||||
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||
it uses a two-step **describe → segment** flow:
|
||||
|
||||
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||
chosen camera (no guessing about the task).
|
||||
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||
episode into consecutive atomic subtasks.
|
||||
|
||||
Both passes see the episode as **timestamped contact sheets**: frames
sampled at `frames_per_second` (`2.0` by default, i.e. one frame every
0.5 s) and packed into JPEG grids with each frame's time burned into its
corner, so the VLM can cite exact boundary times directly. This costs far
fewer vision tokens than one image per frame, so the sampling can stay
dense; episodes that need more than `max_frames_per_prompt` frames are
split into windows at the same density and the results are merged. Both
prompts also carry a causal **event-boundary** definition (a new event
starts when an object becomes held, is released, reaches a new location,
a lid changes state, or contents move) to sharpen where cuts land.
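
The sampling and windowing arithmetic can be sketched as follows. This is a minimal illustration, not the pipeline's code: the function names `sample_timestamps` and `split_into_windows` are hypothetical, and the defaults simply mirror the documented `frames_per_second=2.0` and `max_frames_per_prompt=60`.

```python
from typing import List


def sample_timestamps(duration_s: float, frames_per_second: float = 2.0) -> List[float]:
    """Return evenly spaced sample times covering [0, duration_s)."""
    n_frames = int(duration_s * frames_per_second)
    return [i / frames_per_second for i in range(n_frames)]


def split_into_windows(
    timestamps: List[float], max_frames_per_prompt: int = 60
) -> List[List[float]]:
    """Chunk the samples so that no window exceeds the per-prompt frame budget."""
    return [
        timestamps[i : i + max_frames_per_prompt]
        for i in range(0, len(timestamps), max_frames_per_prompt)
    ]


if __name__ == "__main__":
    ts = sample_timestamps(45.0)  # a 45 s episode sampled at 2 frames/s
    windows = split_into_windows(ts)
    print(len(ts), [len(w) for w in windows])
```

A 45-second episode yields 90 samples, which exceeds the 60-frame budget, so the sketch produces two windows of 60 and 30 frames at the same 0.5 s spacing.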
|
||||
|
||||
The resulting spans are then stitched into a gap-free, full-episode
|
||||
cover, so **every frame has exactly one active subtask**. See
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
for the production settings (single camera, timestamped contact sheets,
|
||||
auto-windowed subtask generation).
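
One way to turn possibly gapped or overlapping spans into a gap-free cover is sketched below. The rule used here (each span runs from its own start to the next span's start, with the first snapped to 0 and the last extended to the episode end) is an assumption chosen for illustration; the pipeline's actual stitching logic may differ, and `stitch_cover` is a hypothetical name.

```python
from typing import List, Tuple

Span = Tuple[float, float, str]  # (start_s, end_s, subtask text)


def stitch_cover(spans: List[Span], episode_end_s: float) -> List[Span]:
    """Tile [0, episode_end_s] so every instant has exactly one subtask."""
    ordered = sorted(spans, key=lambda s: s[0])
    cover: List[Span] = []
    for i, (start, _end, text) in enumerate(ordered):
        # The first span is pulled back to 0; the others keep their own start.
        new_start = 0.0 if i == 0 else start
        # Each span ends where the next begins; the last runs to the episode end.
        new_end = ordered[i + 1][0] if i + 1 < len(ordered) else episode_end_s
        if new_end > new_start:  # drop spans that collapse to zero length
            cover.append((new_start, new_end, text))
    return cover


if __name__ == "__main__":
    raw = [(0.5, 3.0, "reach"), (3.5, 7.0, "grasp"), (6.5, 9.0, "place")]
    for span in stitch_cover(raw, episode_end_s=10.0):
        print(span)
```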
|
||||
|
||||
### Tools
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet. The tool
|
||||
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||
that list, keeping any tools you declared beforehand.
|
||||
|
||||
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||
pipeline preserves whatever is already there. That makes the tool visible
|
||||
to the chat template, so the model can learn to _generate_ the call. The
|
||||
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||
this PR — the [Tools](./tools) doc marks those pieces as
|
||||
not-yet-implemented.
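
The "add `say` but keep what is already there" behaviour can be approximated with a small merge helper. This is a sketch under assumptions: the `SAY_TOOL` dictionary below is a hypothetical OpenAI-style function schema written for illustration (the canonical schema is defined by the pipeline), and `ensure_tool` / `update_info_json` are not LeRobot functions.

```python
import json
from pathlib import Path

# Hypothetical schema for illustration only; not the pipeline's canonical one.
SAY_TOOL = {
    "type": "function",
    "function": {
        "name": "say",
        "description": "Speak a short utterance to the user.",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    },
}


def ensure_tool(info: dict, tool: dict) -> dict:
    """Return a copy of `info` whose "tools" list contains `tool` exactly once."""
    tools = list(info.get("tools", []))
    names = {t.get("function", {}).get("name") for t in tools}
    if tool["function"]["name"] not in names:
        tools.append(tool)  # pre-declared tools are kept, in order
    return {**info, "tools": tools}


def update_info_json(path: Path, tool: dict = SAY_TOOL) -> None:
    """Apply `ensure_tool` to a meta/info.json file in place."""
    info = json.loads(path.read_text())
    path.write_text(json.dumps(ensure_tool(info, tool), indent=2))
```

Because the merge is keyed on the function name, running it repeatedly leaves the file unchanged after the first pass.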
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||
The repo ships a launcher script you copy and tweak for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||
that:
|
||||
|
||||
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||
drives it over the OpenAI-compatible API,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
with `lerobot-annotate`,
|
||||
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||
back to `--repo_id` in place if you leave that unset).
|
||||
|
||||
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||
(run `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Key options
|
||||
|
||||
These are the flags you'll reach for most often. Run
|
||||
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||
short manipulation episodes.
|
||||
|
||||
### Dataset in / out
|
||||
|
||||
| Flag | Default | What it does |
| ----------------- | ------- | ----------------------------------------------------------------------- |
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
| `--root` | — | Annotate a local dataset directory instead. |
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||
|
||||
### Which modules run
|
||||
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
| ------------------------- | ------- | ----------------------------------- |
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
| `--interjections.enabled` | `true` | interjections + speech atoms |
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
| Flag | Default | What it does |
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||
|
||||
### Subtasks / plan / memory (`--plan.*`)
|
||||
|
||||
| Flag | Default | What it does |
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||
|
||||
### Interjections + VQA
|
||||
|
||||
| Flag | Default | What it does |
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||
|
||||
## Contributing new modules
|
||||
|
||||
The pipeline is built to grow, and **contributions are very welcome** —
|
||||
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||
template, a smarter grounding flow, or quality fixes to the existing
|
||||
`plan` / `interjections` / `vqa` modules.
|
||||
|
||||
Every module lives under
|
||||
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||
client and the keyframe cache, writes its raw output to the staging
|
||||
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||
|
||||
## How recipes consume the output
|
||||
|
||||
The annotations are meant to be read by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||
|
||||
- low-level / high-level / memory-update branches read
|
||||
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||
- an interjection-response branch reads `interjection` events plus the
|
||||
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||
and the matching `plan` refresh at the same timestamp.
|
||||
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||
`language_events`.
|
||||
|
||||
## Why state and events are split
|
||||
|
||||
Two ideas shape the design:
|
||||
|
||||
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||
the one frame whose timestamp matches. Timestamps are copied straight
|
||||
from the source parquet — never recomputed in floating point.
|
||||
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||
for one model load per dataset, not three.
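
The difference between the two lookups can be illustrated with a short sketch. It assumes rows are simple `(timestamp, text)` pairs sorted by time, which is a simplification of the real column layout; `active_persistent` and `events_at` are hypothetical helpers, not LeRobot functions.

```python
from bisect import bisect_right
from typing import List, Optional, Tuple

Row = Tuple[float, str]  # (timestamp_s, text); simplified stand-in for a language row


def active_persistent(rows: List[Row], t: float) -> Optional[str]:
    """Persistent lookup: the latest row whose timestamp is <= t ("what's true now?")."""
    times = [ts for ts, _ in rows]
    i = bisect_right(times, t)
    return rows[i - 1][1] if i else None


def events_at(rows: List[Row], t: float) -> List[str]:
    """Event lookup: only rows whose timestamp equals the frame timestamp exactly."""
    return [text for ts, text in rows if ts == t]


if __name__ == "__main__":
    subtasks = [(0.0, "reach for the cup"), (3.5, "lift the cup")]
    events = [(3.5, "interjection: use the other cup")]
    print(active_persistent(subtasks, 2.0), events_at(events, 2.0))
    print(active_persistent(subtasks, 3.5), events_at(events, 3.5))
```

Exact equality on floats is reasonable here only because, as noted above, timestamps are copied from the source parquet rather than recomputed.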
|
||||
|
||||
## Re-running a single module
|
||||
|
||||
Each module stages its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||
prompt iteration cheap: re-running one module overwrites only its own
|
||||
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||
don't want with `--plan.enabled=false` (and likewise
|
||||
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
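
For scripting around the staging tree, the path pattern above can be reproduced with `pathlib`. The helper name `staging_path` is hypothetical; only the directory layout comes from the text.

```python
from pathlib import Path


def staging_path(root: str, episode_index: int, module: str) -> Path:
    """Build <root>/.annotate_staging/episode_{N:06d}/<module>.jsonl."""
    return (
        Path(root)
        / ".annotate_staging"
        / f"episode_{episode_index:06d}"
        / f"{module}.jsonl"
    )


if __name__ == "__main__":
    print(staging_path("/data/my_dataset", 42, "vqa").as_posix())
```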
|
||||
|
||||
## What the validator checks
|
||||
|
||||
Before the writer runs, `StagingValidator` confirms:
|
||||
|
||||
- every event row lands exactly on a real frame timestamp;
|
||||
- no speech / interjection pairs are left orphaned;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||
- each VQA assistant `content` is valid JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row goes to the column chosen by `column_for_style(style)`.
|
||||
|
||||
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||
while debugging.
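
Two of these checks are simple enough to sketch. This is not `StagingValidator` itself: the row shape (a dict with a `"timestamp"` key) and both function names are assumptions made for illustration, and the JSON check only tests parseability, not the bbox / keypoint / count / attribute / spatial shapes.

```python
import json
from typing import Dict, Iterable, List


def check_event_timestamps(
    event_rows: Iterable[Dict], frame_timestamps: Iterable[float]
) -> List[str]:
    """Return one error message per event row that misses every frame timestamp."""
    valid = set(frame_timestamps)
    return [
        f"event at t={row['timestamp']} does not match any frame"
        for row in event_rows
        if row["timestamp"] not in valid
    ]


def is_valid_json(content: str) -> bool:
    """True if the VQA assistant content parses as JSON."""
    try:
        json.loads(content)
    except json.JSONDecodeError:
        return False
    return True


if __name__ == "__main__":
    frames = [0.0, 0.5, 1.0]
    rows = [{"timestamp": 0.5}, {"timestamp": 0.7}]
    print(check_event_timestamps(rows, frames))
    print(is_valid_json('{"count": 2}'), is_valid_json("two cups"))
```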
|
||||
|
||||
## Where each module's ideas come from
|
||||
|
||||
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||
for "how, not what" detail.
|
||||
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||
keep only the minimal relevant information — preserve outcomes, drop
|
||||
specific attributes.
|
||||
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom
|
||||
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies
|
||||
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||
|
||||
When improving a module, tweak its prompt template in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||
rewriting from scratch.
|
||||
|
||||
## Roughly how much it costs
|
||||
|
||||
Per episode, the pipeline makes about `plan_max_steps` plan calls,
`max_interjections_per_episode` interjection calls, and
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
~50 VLM calls.
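
The three terms named above can be added up with a few lines of arithmetic. This sketch models only those terms; the function name and parameters are hypothetical, and any other VLM calls the pipeline makes are not included, so the result is a rough floor and is not meant to reproduce the ~50 figure exactly.

```python
def estimate_vlm_calls(
    plan_max_steps: int,
    max_interjections_per_episode: int,
    vqa_emission_hz: float,
    episode_seconds: float,
) -> int:
    """Sum the plan, interjection, and VQA call counts for one episode."""
    vqa_calls = round(vqa_emission_hz * episode_seconds)
    return plan_max_steps + max_interjections_per_episode + vqa_calls


if __name__ == "__main__":
    # 8 plan calls + 1 interjection call + 1 Hz x 30 s of VQA calls
    print(estimate_vlm_calls(8, 1, 1.0, 30.0))
```

The VQA term dominates, so lowering `vqa_emission_hz` is the most direct way to cut the per-episode call count under this model.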
|
||||
|
||||
Storage stays small: `language_persistent` is at most tens of KB per
|
||||
episode (parquet dictionary-encodes the one entry that repeats across
|
||||
frames), and `language_events` is empty on most frames — its size scales
|
||||
with the number of emissions, not `num_frames × num_emissions`.
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||
|
||||
Spawns one single-GPU ``h200`` job that:
|
||||
|
||||
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (each episode generates its own subtasks +
|
||||
memory),
|
||||
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||
run. For larger datasets, scale to ``h200x4`` and raise
|
||||
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||
"--vlm.num_gpus=1 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
+12
-32
@@ -115,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -142,8 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -178,12 +177,7 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -205,7 +199,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -220,41 +214,27 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||
# (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||
# uv's single unified lock would then cap ``torch`` for every extra
|
||||
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||
# install it locally only if you run your own ``vllm serve``.
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependency tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -317,6 +297,7 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -338,7 +319,6 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
@@ -357,7 +337,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# --- Frame input: timestamped contact sheets (always on) ---------------
|
||||
# The subtask describe/segment passes ALWAYS render the episode as
|
||||
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
|
||||
# grids with each frame's timestamp burned into its corner, so the VLM
|
||||
# cites the exact source time of a boundary directly. This is far cheaper
|
||||
# in vision tokens than one image per frame (≈2× faster subtask generation
|
||||
# in practice), which is why the sampling is dense by default.
|
||||
#
|
||||
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
|
||||
frames_per_second: float = 2.0
|
||||
# Frame budget per VLM call (= columns × rows × sheets). When a whole
|
||||
# episode sampled at ``frames_per_second`` exceeds this, the episode is
|
||||
# AUTOMATICALLY split into consecutive windows of
|
||||
# ``max_frames_per_prompt`` frames each (one describe→segment call per
|
||||
# window, still at the full ``frames_per_second`` density), and the
|
||||
# per-window spans are merged + stitched into one contiguous cover. So an
|
||||
# episode of any length is always covered at the full sampling density.
|
||||
max_frames_per_prompt: int = 60
|
||||
contact_sheet_columns: int = 5
|
||||
contact_sheet_frames_per_sheet: int = 20
|
||||
contact_sheet_frame_width: int = 224
|
||||
contact_sheet_quality: int = 84
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||
# invented from the task text (+1 VLM call/episode).
|
||||
subtask_describe_first: bool = True
|
||||
|
||||
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||
emit_plan: bool = True
|
||||
|
||||
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
|
||||
# Symmetric counterpart of ``emit_plan``.
|
||||
emit_memory: bool = True
|
||||
|
||||
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||
|
||||
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAugAxesConfig:
|
||||
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
synonym_paraphrase: int = 3
|
||||
omit_arm: int = 3
|
||||
omit_orientation: int = 2
|
||||
omit_grasp_method: int = 2
|
||||
combined_omissions: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||
restrict_to_default_camera: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# Only ``openai`` is supported (an OpenAI-compatible vLLM server, auto-spawned
# when auto_serve=True); ``stub`` is for tests.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-27B"
|
||||
|
||||
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
|
||||
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||
max_model_len: int | None = None
|
||||
|
||||
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||
|
||||
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
|
||||
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||
# is on and ``new_repo_id`` unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||
new_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/``.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
|
||||
# library default (torchcodec when available, else PyAV). Or pin
|
||||
# ``"torchcodec"`` / ``"pyav"`` explicitly.
|
||||
video_backend: str | None = None
|
||||
|
||||
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
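The `PlanConfig` comments above describe how a long episode is split into windows of `max_frames_per_prompt` frames and packed into contact sheets. The following is a minimal sketch of that arithmetic only, not the pipeline's actual implementation: `plan_windows` and `sheet_layout` are hypothetical helper names, and sampling from t=0 up to and including the episode end is an assumption. The default values are taken from the config fields.

```python
import math


def plan_windows(
    duration_s: float,
    frames_per_second: float = 2.0,
    max_frames_per_prompt: int = 60,
) -> list[list[float]]:
    """Sample timestamps at a fixed rate and split them into per-call windows.

    Hypothetical helper: assumes sampling starts at t=0 and includes the last
    sample that fits inside the episode duration.
    """
    if duration_s < 0 or frames_per_second <= 0 or max_frames_per_prompt <= 0:
        raise ValueError("duration must be >= 0 and the rate and budget must be > 0")
    n_frames = math.floor(duration_s * frames_per_second) + 1
    timestamps = [i / frames_per_second for i in range(n_frames)]
    # Consecutive windows keep the full sampling density; only the last
    # window may be shorter than the frame budget.
    return [
        timestamps[start : start + max_frames_per_prompt]
        for start in range(0, n_frames, max_frames_per_prompt)
    ]


def sheet_layout(
    n_frames: int,
    frames_per_sheet: int = 20,
    columns: int = 5,
) -> list[tuple[int, int]]:
    """Return (rows, frame_count) for each contact sheet of one window."""
    layout = []
    for start in range(0, n_frames, frames_per_sheet):
        count = min(frames_per_sheet, n_frames - start)
        layout.append((math.ceil(count / columns), count))
    return layout


windows = plan_windows(45.0)  # a 45 s episode sampled at 2 frames per second
for window in windows:
    print(len(window), "frames from", window[0], "s:", sheet_layout(len(window)))
```

With the defaults, a full window of 60 frames maps to three sheets of 5 columns by 4 rows, which matches the "columns × rows × sheets" budget described in the comment.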
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:

    phase 1: ``plan`` module (plan + subtasks + memory)
    phase 2: ``interjections`` module (interjections + speech)
    phase 3: ``plan`` plan-update pass, which re-runs plan emission at every
             interjection timestamp produced by phase 2
    phase 4: ``vqa`` module (VQA)
    phase 5: validator
    phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
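The `_ensure_annotation_metadata_in_info` method above is described as idempotent and additive. The following is a minimal sketch of that merge pattern on a plain dictionary, under stated assumptions: `merge_annotation_metadata` is a hypothetical name, `SAY_TOOL` is a stand-in for `SAY_TOOL_SCHEMA` (whose real contents are not shown in this diff), and the feature entries are placeholders rather than the output of `language_feature_info()`.

```python
from copy import deepcopy

# Hypothetical stand-in for SAY_TOOL_SCHEMA; the real schema is not shown here.
SAY_TOOL = {"type": "function", "function": {"name": "say", "parameters": {}}}

# Placeholder feature entries; the real ones come from language_feature_info().
LANGUAGE_FEATURES = {
    "language_persistent": {"dtype": "language"},
    "language_events": {"dtype": "language"},
}


def merge_annotation_metadata(info: dict, language_features: dict, tool: dict) -> bool:
    """Additively merge features and one tool into ``info``.

    Returns True only when ``info`` was modified, so the caller can skip
    rewriting the file when nothing changed.
    """
    changed = False

    # Existing features are preserved; language features are added or refreshed.
    merged = {**info.get("features", {}), **language_features}
    if merged != info.get("features"):
        info["features"] = merged
        changed = True

    # Append the tool only if no existing tool already uses its function name.
    existing = info.get("tools") or []
    names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
    if tool["function"]["name"] not in names:
        info["tools"] = [*existing, tool]
        changed = True

    return changed


info = {"features": {"action": {"dtype": "float32"}}}
first = merge_annotation_metadata(info, LANGUAGE_FEATURES, SAY_TOOL)
snapshot = deepcopy(info)
second = merge_annotation_metadata(info, LANGUAGE_FEATURES, SAY_TOOL)
print(first, second, info == snapshot)
```

Returning a `changed` flag mirrors the source's guard around `write_info`, so a second run over an already annotated dataset leaves the metadata untouched.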
@@ -1,481 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import VideoEncoderConfig
|
||||
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
|
||||
|
||||
from .reader import EpisodeRecord, snap_to_frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
:attr:`camera_keys` so each frame's grounded answer (bbox/keypoint/...)
is tagged with the camera it was grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend forwarded to
|
||||
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
|
||||
# uses the library default (torchcodec when available, else PyAV).
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
# Serializes decode_video_frames calls: torchcodec hands out one
|
||||
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
|
||||
# is not safe to drive from multiple threads at once.
|
||||
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||
# only for video-stored cameras. Image-stored cameras (also in
|
||||
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||
# default — to video keys.
|
||||
keys = list(self._meta.video_keys)
|
||||
# Last-resort fallback: if metadata didn't surface any video keys but
# the caller explicitly named a camera (``--vlm.camera_key=...``),
# trust that key, on the assumption that the caller knows it exists on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
# Snap each request to the nearest real frame timestamp: callers
|
||||
# sample uniform grids whose points land mid-frame, and
|
||||
# ``decode_video_frames`` rejects queries farther than
|
||||
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
|
||||
# repeat queries through the cache.
|
||||
if record.frame_timestamps:
|
||||
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# Drop any ``None`` placeholders left over from decode failures.
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
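The snap-then-cache flow in `frames_at` can be illustrated in isolation. The following is a minimal sketch, not the repository's code: `snap_to_nearest` and `FifoCache` are hypothetical stand-ins (the real `snap_to_frame` lives in the reader module and is not shown in this diff), and the sketch assumes `frame_timestamps` is sorted in ascending order.

```python
from bisect import bisect_left


def snap_to_nearest(ts, frame_timestamps):
    """Return the entry of sorted ``frame_timestamps`` closest to ``ts``."""
    pos = bisect_left(frame_timestamps, ts)
    if pos == 0:
        return frame_timestamps[0]
    if pos == len(frame_timestamps):
        return frame_timestamps[-1]
    before, after = frame_timestamps[pos - 1], frame_timestamps[pos]
    # Ties go to the earlier frame.
    return before if ts - before <= after - ts else after


class FifoCache:
    """Bounded cache that evicts the oldest inserted key when full."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._data = {}  # dicts preserve insertion order

    def get(self, key):
        return self._data.get(key)

    def put(self, key, value):
        if key not in self._data and len(self._data) >= self.max_size:
            self._data.pop(next(iter(self._data)))  # drop the oldest entry
        self._data[key] = value


frames = [0.0, 0.1, 0.2, 0.3]
cache = FifoCache(max_size=2)
for query in (0.14, 0.16, 0.29):
    snapped = snap_to_nearest(query, frames)
    # The key mirrors the (episode, camera, rounded timestamp) tuple used above.
    cache.put((0, "cam", round(snapped, 6)), f"frame@{snapped}")
```

Because queries are snapped before the cache lookup, two requests that land inside the same frame interval resolve to the same key and the second one is served from the cache.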
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
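The uniform sampling in `video_for_episode` can be checked on its own. This sketch reproduces only the timestamp selection as a free function; the name `uniform_timestamps` is hypothetical, and the decode step that follows in the real method is left out.

```python
def uniform_timestamps(frame_timestamps, max_frames):
    """Pick up to ``max_frames`` timestamps spread evenly over the episode."""
    if max_frames <= 0 or not frame_timestamps:
        return []
    n = min(max_frames, len(frame_timestamps))
    if n == len(frame_timestamps):
        # Budget covers every frame: return them all.
        return list(frame_timestamps)
    t0, t_last = frame_timestamps[0], frame_timestamps[-1]
    if t_last <= t0:
        # Degenerate episode with zero duration.
        return [float(t0)] * n
    step = (t_last - t0) / (n - 1) if n > 1 else 0.0
    return [float(t0 + i * step) for i in range(n)]


timestamps = [i / 10 for i in range(11)]  # 0.0 .. 1.0 at 10 fps
sampled = uniform_timestamps(timestamps, 3)
```

With `n > 1` the first and last samples coincide with the first and last frame, so the whole episode duration is covered. With `n == 1` only the first frame is returned.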
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks or extraction
|
||||
failed. Skips re-extraction when the cached clip already exists.
|
||||
Re-encodes to H.264 via
|
||||
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
|
||||
mp4 is decodable by every downstream video processor — stream-copy
|
||||
would inherit the source codec (often AV1 in modern LeRobot
|
||||
datasets), which vllm's libav build cannot decode.
|
||||
"""
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
|
||||
try:
|
||||
reencode_video(
|
||||
src,
|
||||
out_path,
|
||||
camera_encoder=encoder,
|
||||
overwrite=True,
|
||||
start_time_s=from_timestamp,
|
||||
end_time_s=to_timestamp,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
|
||||
)
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
|
||||
one explicitly). Returns one frame per requested timestamp, or ``[]``
|
||||
if decoding failed — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
# The module phases run decodes under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
|
||||
# per-file decoder is single-threaded, so serialize decodes on a
|
||||
# dedicated lock. Frame extraction is a small fraction of episode
|
||||
# wall time (VLM calls dominate), so the contention is cheap.
|
||||
with self._decode_lock:
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time so a silent vqa-module no-op (every
|
||||
# prompt skipped because frames_at returned []) is debuggable from
|
||||
# the job log instead of post-hoc parquet inspection. Subsequent
|
||||
# failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = self._warned_decode_fail
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
self.video_backend,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` or empty so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
|
||||
|
||||
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
|
||||
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
|
||||
|
||||
A solid black badge with white text, so a VLM reading a contact sheet can
|
||||
cite the exact source time of each tile (e.g. ``012.50s``) directly,
|
||||
instead of the caller having to map tile position back to time. Mirrors
|
||||
the macrodata/refiner contact-sheet convention.
|
||||
"""
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
result = image.copy()
|
||||
draw = ImageDraw.Draw(result)
|
||||
font = ImageFont.load_default()
|
||||
label = f"{timestamp:06.2f}s"
|
||||
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
|
||||
text_w, text_h = right - left, bottom - top
|
||||
pad = max(3, round(min(image.width, image.height) * 0.018))
|
||||
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
|
||||
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
|
||||
return result
|
||||
|
||||
|
||||
def to_contact_sheet_blocks(
|
||||
frames: Sequence[Any],
|
||||
timestamps: Sequence[float],
|
||||
*,
|
||||
columns: int = 5,
|
||||
frames_per_sheet: int = 20,
|
||||
frame_width: int = 224,
|
||||
quality: int = 84,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
|
||||
|
||||
Each frame is resized to ``frame_width`` wide, stamped with its
|
||||
episode-relative timestamp, and tiled row-major into grids of
|
||||
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
|
||||
block is returned per grid; many frames collapse into a few images, so a
|
||||
long episode's temporal coverage stays dense at a fraction of the vision
|
||||
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
|
||||
aligned and of equal length (extras in the longer one are ignored). Returns ``[]`` for empty input.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
columns = max(1, columns)
|
||||
frames_per_sheet = max(1, frames_per_sheet)
|
||||
rows_per_sheet = math.ceil(frames_per_sheet / columns)
|
||||
|
||||
tiles: list[PIL.Image.Image] = []
|
||||
for ts, frame in zip(timestamps, frames, strict=False):
|
||||
img = _frame_to_pil(frame)
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
continue
|
||||
img = img.convert("RGB")
|
||||
if img.width != frame_width:
|
||||
height = max(1, round(img.height * frame_width / img.width))
|
||||
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
|
||||
tiles.append(_draw_timestamp_badge(img, float(ts)))
|
||||
if not tiles:
|
||||
return []
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for start in range(0, len(tiles), frames_per_sheet):
|
||||
chunk = tiles[start : start + frames_per_sheet]
|
||||
cell_w = max(tile.width for tile in chunk)
|
||||
cell_h = max(tile.height for tile in chunk)
|
||||
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
|
||||
for i, tile in enumerate(chunk):
|
||||
x = (i % columns) * cell_w
|
||||
y = (i // columns) * cell_h
|
||||
sheet.paste(tile, (x, y))
|
||||
# JPEG round-trip at ``quality`` to match the refiner convention and
|
||||
# shrink the wire payload; vision-token count is set by resolution, so
|
||||
# the real saving is the grid packing, not the codec.
|
||||
buf = io.BytesIO()
|
||||
sheet.save(buf, format="JPEG", quality=quality)
|
||||
buf.seek(0)
|
||||
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
|
||||
return blocks
|
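The grid arithmetic in `to_contact_sheet_blocks` is easy to verify without Pillow. The sketch below computes only which sheet and which `(column, row)` cell each tile lands in; `sheet_layout` is a hypothetical helper introduced for illustration, and the defaults copy the ones in the signature above.

```python
import math


def sheet_layout(n_tiles, columns=5, frames_per_sheet=20):
    """Return one list of (column, row) cells per contact sheet."""
    columns = max(1, columns)
    frames_per_sheet = max(1, frames_per_sheet)
    layout = []
    for start in range(0, n_tiles, frames_per_sheet):
        count = min(frames_per_sheet, n_tiles - start)
        # Row-major placement: fill a row left to right, then move down.
        layout.append([(i % columns, i // columns) for i in range(count)])
    return layout


layout = sheet_layout(23)
rows_per_sheet = math.ceil(20 / 5)
```

Every sheet is allocated the full `columns * rows_per_sheet` canvas, so a partially filled last sheet keeps black padding in its unused cells.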
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
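For reference, the tick-and-anchor scheme above can also be written with a binary search and an integer tick counter. This is an alternative sketch under stated assumptions, not the module's implementation: `anchor_indices` is a hypothetical name, `frame_timestamps` is assumed sorted, and computing each tick as `t0 + tick * period` is a substitution made here to avoid accumulating floating-point error from repeated `t += period`.

```python
from bisect import bisect_left


def anchor_indices(frame_timestamps, hz, k):
    """Return deduplicated frame indices: ``k`` consecutive frames per tick."""
    if hz <= 0 or k <= 0 or not frame_timestamps:
        return []
    t0, t_last = frame_timestamps[0], frame_timestamps[-1]
    period = 1.0 / hz
    seen = set()
    ordered = []
    tick = 0
    while t0 + tick * period <= t_last + 1e-9:
        t = t0 + tick * period
        pos = bisect_left(frame_timestamps, t)
        if pos == len(frame_timestamps):
            nearest = pos - 1
        elif pos > 0 and t - frame_timestamps[pos - 1] <= frame_timestamps[pos] - t:
            nearest = pos - 1  # ties go to the earlier frame
        else:
            nearest = pos
        # Anchor up to k consecutive frames, stopping at the episode end.
        for j in range(nearest, min(nearest + k, len(frame_timestamps))):
            if j not in seen:
                seen.add(j)
                ordered.append(j)
        tick += 1
    return ordered


timestamps = [i / 10 for i in range(30)]  # 3 seconds at 10 fps
anchors = anchor_indices(timestamps, hz=1.0, k=2)
```

The binary search replaces the linear scan per tick, which matters only for long episodes; the set-based dedupe preserves first-seen order, as in the function above.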
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not self._warned_no_camera:
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
|
||||
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||
only the provider's default camera (the single ``--vlm.camera_key``
|
||||
stream), matching the plan / interjection modules so the whole
|
||||
pipeline focuses on one view.
|
||||
"""
|
||||
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
if getattr(self.config, "restrict_to_default_camera", False):
|
||||
default = getattr(self.frame_provider, "camera_key", None)
|
||||
if default and default in all_cameras:
|
||||
return [default]
|
||||
# ``restrict_to_default_camera`` is set but the configured default
|
||||
# isn't one the provider exposes. Returning it anyway would make
|
||||
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||
# fall through to every available camera instead.
|
||||
if default:
|
||||
logging.getLogger(__name__).warning(
|
||||
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||
default,
|
||||
all_cameras,
|
||||
)
|
||||
return all_cameras
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("interjections_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("interjections_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
first_ts = float(frame_timestamps[0])
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(first_ts, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
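A worked example of the centered window may help. The sketch below computes only the clamped target times; it omits the snap-to-frame and dedupe steps that `_window_timestamps` applies afterwards, and `window_targets` is a hypothetical name used for illustration.

```python
def window_targets(t_anchor, window_seconds, n_frames, first_ts, last_ts):
    """Return ``n_frames`` evenly spaced targets centered on ``t_anchor``."""
    n = max(1, int(n_frames))
    if n == 1:
        return [t_anchor]
    step = window_seconds / (n - 1)
    start = t_anchor - window_seconds / 2.0
    # Clamp each target into the episode's time range.
    return [min(last_ts, max(first_ts, start + step * i)) for i in range(n)]


centered = window_targets(5.0, 2.0, 5, first_ts=0.0, last_ts=10.0)
clamped = window_targets(0.5, 2.0, 5, first_ts=0.0, last_ts=10.0)
```

Near the start or end of an episode several targets clamp to the same value, which is why the real method dedupes after snapping and can return fewer than `interjection_window_frames` timestamps.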
||||
@@ -1,780 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
null_provider,
|
||||
to_contact_sheet_blocks,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Prepended to every describe / segment prompt so the VLM knows the images are
|
||||
# timestamped contact-sheet grids, not a single video, and reads the burned-in
|
||||
# per-tile timestamp when choosing boundaries.
|
||||
def _contact_sheet_preamble(columns: int) -> str:
|
||||
return (
|
||||
"CONTACT SHEETS — how to read the images below:\n"
|
||||
f"- Each image is a grid of sampled video frames, {columns} per row, "
|
||||
"with time running left-to-right then top-to-bottom (row-major).\n"
|
||||
"- Each frame has its timestamp burned into the top-left corner, e.g. "
|
||||
'"012.50s". Use that printed timestamp (not the tile position) when you '
|
||||
"choose start/end times; boundaries should land on or near a printed "
|
||||
"timestamp.\n"
|
||||
"- Frames continue across grids: an action may span the end of one sheet "
|
||||
"and the start of the next, so do not place a boundary just because a new "
|
||||
"image begins.\n\n"
|
||||
)
|
||||
|
||||
|
||||
# Appended to every describe (and segment) prompt. A visual, causal definition
|
||||
# of where one event ends and the next begins — adapted from macrodata/refiner —
|
||||
# to sharpen cut points while the existing prompt keeps owning the imperative
|
||||
# phrasing.
|
||||
_CAUSAL_BOUNDARY_RULES = (
|
||||
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
|
||||
"- Start a new event whenever the world state changes: an object becomes "
|
||||
"held (the gripper closes on it), an object is released (the gripper opens "
|
||||
"and it stays put), an object reaches a new location, a lid/door/drawer "
|
||||
"changes open/closed state, a tool starts or stops affecting a surface, or "
|
||||
"contents visibly move (e.g. poured).\n"
|
||||
"- If a single action changes the same state gradually and continuously, "
|
||||
"keep it as ONE event — do not split it.\n"
|
||||
"- If the same action repeats on different objects or target locations, "
|
||||
"treat each repetition as a separate event.\n"
|
||||
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
|
||||
"tiny hand adjustments."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at every subtask boundary (including ``t=0``); refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||
# numbered list of still-todo subtasks, so re-emitting at each
|
||||
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||
# exactly what's left to do.
|
||||
if self.config.emit_plan:
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start;
|
||||
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
|
||||
prior_memory = ""
|
||||
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
|
||||
for i, span in memory_boundaries:
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
|
||||
|
||||
The prompt is always prefixed with a short explanation of how to read
|
||||
the timestamped grids, so the model treats them as one ordered
|
||||
sequence of frames rather than unrelated images.
|
||||
"""
|
||||
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||
"""One VLM call → variants along the 5-axis taxonomy.
|
||||
|
||||
Variants from all axes are flattened into a single list (the
|
||||
downstream pipeline doesn't need to know about the per-axis
|
||||
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||
is preserved for reproducibility: synonym_paraphrase first,
|
||||
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||
then combined_omissions.
|
||||
"""
|
||||
if not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_aug_axes").format(
|
||||
base_task=base_task,
|
||||
n_synonym=axes_cfg.synonym_paraphrase,
|
||||
n_omit_arm=axes_cfg.omit_arm,
|
||||
n_omit_orientation=axes_cfg.omit_orientation,
|
||||
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||
n_combined=axes_cfg.combined_omissions,
|
||||
)
|
||||
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
ordered_axes = (
|
||||
"synonym_paraphrase",
|
||||
"omit_arm",
|
||||
"omit_orientation",
|
||||
"omit_grasp_method",
|
||||
"combined_omissions",
|
||||
)
|
||||
flat: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for axis in ordered_axes:
|
||||
entries = result.get(axis)
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
key = item.strip().strip('"').strip("'")
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Timestamped contact sheets for the describe / segmentation prompts.
|
||||
|
||||
Always renders the (optionally windowed) episode as contact sheets:
|
||||
frames sampled at ``frames_per_second`` and packed into timestamped
|
||||
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
|
||||
episodes that exceed it are windowed upstream in
|
||||
:meth:`_generate_subtasks` so each call stays within budget while the
|
||||
full episode keeps its sampling density.
|
||||
|
||||
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
|
||||
(``ts - w0``) to match the window-relative time frame the
|
||||
segmentation prompt works in (spans are offset back to absolute time
|
||||
afterwards).
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
rel = [ts - w0 for ts in timestamps[: len(frames)]]
|
||||
return self._contact_sheet_blocks(frames, rel)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
timestamps = self._uniform_episode_timestamps(record, n)
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
|
||||
|
||||
@staticmethod
|
||||
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
|
||||
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
|
||||
ts = record.frame_timestamps
|
||||
if n >= len(ts):
|
||||
return [float(t) for t in ts]
|
||||
t0, t_last = float(ts[0]), float(ts[-1])
|
||||
if t_last <= t0 or n <= 1:
|
||||
return [t0] * max(1, n)
|
||||
step = (t_last - t0) / (n - 1)
|
||||
return [t0 + i * step for i in range(n)]
|
||||
|
||||
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
|
||||
"""Build timestamped contact-sheet image blocks from decoded frames."""
|
||||
return to_contact_sheet_blocks(
|
||||
frames,
|
||||
timestamps,
|
||||
columns=self.config.contact_sheet_columns,
|
||||
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
|
||||
frame_width=self.config.contact_sheet_frame_width,
|
||||
quality=self.config.contact_sheet_quality,
|
||||
)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
This method adds plans only at user interjections (event-driven); the
per-boundary plans are written by :meth:`run_episode`. The interjection
text is passed to :meth:`_generate_plan`, which currently ignores it and
lists the subtasks that start at or after the refresh time.
|
||||
"""
|
||||
if not self.config.emit_plan:
|
||||
return
|
||||
existing = staging.read("plan")
|
||||
# Pass the last frame timestamp so the final span is closed (else its
|
||||
# end == start, zero duration, and a refresh inside it is missed).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||
|
||||
Single call (default): watch video → emit subtask JSON.
|
||||
|
||||
Multi-call (opt-in, higher quality, more VLM calls):
|
||||
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||
its description is injected into the segmentation prompt so
|
||||
the model segments its own grounded observations instead of
|
||||
pattern-matching the task text.
|
||||
2. segmentation — emit subtask JSON (as before).
|
||||
"""
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Auto-windowing (keeps the full sampling density) --------
|
||||
# Contact sheets are cheap, but a whole long episode sampled at
|
||||
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
|
||||
# When it does, split into consecutive windows of exactly that many
|
||||
# frames (one describe→segment call each, still at the full sampling
|
||||
# density), then merge + stitch — so an episode of any length is
|
||||
# covered at full density rather than subsampled into one sparse call.
|
||||
fps = max(1e-6, float(self.config.frames_per_second))
|
||||
n_whole = int(round(episode_duration * fps)) + 1
|
||||
if n_whole > self.config.max_frames_per_prompt:
|
||||
window_s = self.config.max_frames_per_prompt / fps
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, effective_task)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the video) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
# ---- Pass 2: segmentation ------------------------------------
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=effective_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
cleaned = self._clean_spans(spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||
# with no active subtask. Always stitch into a contiguous
|
||||
# [t0, t_last] cover.
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment chain on each window's own
|
||||
frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _with_causal_rules(prompt: str) -> str:
|
||||
"""Append the causal event-boundary rules to a describe/segment prompt."""
|
||||
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
|
||||
|
||||
def _clean_spans(
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||
cost a round-trip per episode/refresh and could diverge).
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
|
||||
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||
starting at or after ``refresh_t`` are included — so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
"""Ask the VLM for an updated memory note once ``completed`` is done; returns ``""`` if the reply has no string ``memory`` field."""
|
||||
prompt = load_prompt("plan_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
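The auto-windowing arithmetic in `_generate_subtasks` and `_generate_subtasks_windowed` can be sketched on its own. This is a simplified illustration, not the module's code: `plan_windows` and `to_absolute` are hypothetical helper names, and the sketch returns a single whole-episode window where the source takes its single-call path.

```python
from typing import Any


def plan_windows(
    t0: float, t_last: float, frames_per_second: float, max_frames_per_prompt: int
) -> list[tuple[float, float]]:
    """Split ``[t0, t_last]`` into consecutive windows that fit the frame budget."""
    fps = max(1e-6, float(frames_per_second))
    n_whole = int(round((t_last - t0) * fps)) + 1  # frames needed for one call
    if n_whole <= max_frames_per_prompt:
        return [(t0, t_last)]
    window_s = max_frames_per_prompt / fps  # window length in seconds
    windows: list[tuple[float, float]] = []
    w0 = t0
    while w0 < t_last - 1e-6:
        w1 = min(w0 + window_s, t_last)  # the last window may be shorter
        windows.append((w0, w1))
        w0 = w1
    return windows


def to_absolute(spans: list[dict[str, Any]], w0: float) -> list[dict[str, Any]]:
    """Offset window-relative spans back to absolute episode time."""
    return [{**s, "start": w0 + float(s["start"]), "end": w0 + float(s["end"])} for s in spans]


windows = plan_windows(t0=0.0, t_last=25.0, frames_per_second=2.0, max_frames_per_prompt=20)
print(windows)  # [(0.0, 10.0), (10.0, 20.0), (20.0, 25.0)]

relative = [{"text": "grasp the cup", "start": 1.0, "end": 4.0}]
print(to_absolute(relative, w0=10.0))
```

Working in window-relative time means each call sees a clip that starts at 0; the offset step is what puts the spans back on the episode clock before they are merged.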
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.

One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""

from __future__ import annotations

from pathlib import Path

_DIR = Path(__file__).parent


def load(name: str) -> str:
    """Read prompt template ``name.txt`` from the ``prompts/`` directory."""
    path = _DIR / f"{name}.txt"
    return path.read_text(encoding="utf-8")
|
||||
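Because the templates go through `str.format`, literal JSON braces in a template have to be doubled (`{{` and `}}`), while single braces mark substitution fields. A minimal sketch of that behaviour follows; the inline `TEMPLATE` string is a shortened stand-in written for this example, not one of the actual template files.

```python
# Shortened, hypothetical template text in the same style as the prompt files.
TEMPLATE = (
    'The user just asked the robot: "{episode_task}".\n'
    "Output strictly valid JSON:\n"
    '{{ "text": "<the spoken acknowledgement>" }}\n'
)

# Single-brace fields are substituted; doubled braces collapse to literal ones.
rendered = TEMPLATE.format(episode_task="wipe the table")

# Format specs such as ``:.1f`` also work inside template fields.
window_line = "roughly {window_seconds:.1f} seconds".format(window_seconds=2.0)
```

One consequence of this design is that a template containing an unescaped `{` or `}` raises at format time, so malformed templates tend to fail loudly instead of producing a silently wrong prompt.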
@@ -1,12 +0,0 @@
The user just asked the robot: "{episode_task}".

Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: compact, confident, friendly.

Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".

Prefer very short replies: "Got it.", "On it.", "OK."

Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -1,46 +0,0 @@
You are generating training data for a Hi Robot-style hierarchical
robot policy. The robot in this demonstration has ALREADY executed
every step shown in the video — we cannot retroactively change the
action stream. To keep training data consistent with the video, the
"interjection" must align with what the robot is *about to do next* in
the demonstration, framed as a natural mid-task user request.

The episode's overall task: "{episode_task}".

The images above show roughly {window_seconds:.1f} seconds straddling a
subtask boundary in the demonstration:

- Subtask the robot just finished: "{prev_subtask}"
- Subtask the robot is about to start: "{next_subtask}"
- Time into episode: {timestamp:.2f}s

Write ONE compact interjection the user would naturally say at this
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
Keep it like a mid-task coaching cue, not a full instruction paragraph.
Also write the robot's compact verbal acknowledgement.

Hard rules:

- The interjection MUST be consistent with the next subtask. The user
  cannot ask for something different from what the robot then does in
  the video. If you're tempted to say "actually skip X" or "do Y
  instead", DO NOT — those would contradict the demonstration.
- The interjection must reference an object, location, or action that
  is plausible given the visible scene and the next subtask text.
- One short phrase or sentence each. Conversational, not robotic.
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
- Keep robot speech very short: "OK.", "On it.", "Doing that."

Style examples (vary the phrasing — don't reuse these verbatim):
- "Now go ahead and {next_subtask}."
- "Great, can you {next_subtask} next?"
- "{next_subtask}, please."
- "Before you continue, please {next_subtask}."
- "Looking good — {next_subtask} now."
- "Okay, {next_subtask}."

Output strictly valid JSON:
{{
  "interjection": "<short cue from the user, asking for the next subtask>",
  "speech": "<short robot acknowledgement>"
}}
@@ -1,36 +0,0 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.

Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."

Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}

Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
robot has accomplished so far — the running story it would tell itself.

Authoring rules:
- First person, past tense. Every sentence starts with "I": "I picked
  up...", "I opened...", "I moved to...".
- One or two short sentences. Extend the previous memory with the
  just-completed subtask; do not rewrite it from scratch.
- Keep WHAT happened (functional outcomes — where items went, how many),
  drop HOW (grasp details, motions).
- Compress completed steps and drop object attributes (colors, exact
  counts) once they no longer affect the remaining subtasks.

Example (MEM, Torne 2026):
Before: "I prepared the pot and got the potatoes, milk, and butter. I
moved to the drawer."
After: "I prepared the pot and got the ingredients. I opened the
drawer with the masher."

Output strictly valid JSON:
{{ "memory": "<one or two short first-person past-tense sentences>" }}
@@ -1,27 +0,0 @@
You are watching a teleoperated robot demonstration from a single
camera. The user asked the robot to: "{episode_task}"

This is an OBSERVATION pass. Watch the entire clip and describe, in
chronological order, ONLY what the robot physically does — the concrete
motions, approaches, contacts, grasps, releases, and relocations you can
actually SEE in the frames.

Hard rules:
- Describe only motion visible in the video. Do NOT use the task
  instruction to guess steps that aren't shown. The instruction is the
  goal; the video is ground truth.
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
  the single field below. Just narrate what happens.
- Give an approximate timestamp (in seconds) for each distinct event,
  e.g. "0.0-1.4s: the base drives forward toward the stove".
- Do NOT invent objects, grasps, destinations, or steps. If the robot
  only does one thing (e.g. it just navigates and the clip ends), say
  exactly that and nothing more.
- Be concrete and literal. "the gripper closes on the mug" — not "the
  robot prepares to make coffee".

Output strictly valid JSON:

{{
  "description": "<chronological, timestamped description of ONLY what is visible>"
}}
@@ -1,112 +0,0 @@
You are labeling a teleoperated robot demonstration.

The user originally asked: "{episode_task}"

You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs.

{observation_block}GROUNDING — read this first, it overrides everything below:
- Label ONLY what the robot actually does in the video. Every subtask
  you emit must correspond to motion you can SEE in specific frames.
- Do NOT invent, anticipate, or pad. If the robot only does one thing
  (e.g. it just navigates to a location and the clip ends), emit
  EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
  subtasks than the ceiling is not just allowed, it is expected for
  short / atomic demonstrations. One correct subtask is far better
  than several invented ones.
- If the video does not clearly show the action implied by the task,
  describe what you actually see — do NOT fabricate the task's steps
  from the instruction text. The instruction tells you the goal; the
  VIDEO is the ground truth for what happened.

Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:

- Each subtask = one COMPOSITE atomic skill the low-level policy can
  execute end-to-end. A "skill" bundles its own approach motion with
  its terminal action — do NOT split the approach off as its own
  subtask. The whole-arm policy already learns to reach as part of
  every manipulation primitive.
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
  these verbs (extend only when none fits):
      pick up <obj>           — approach + grasp + lift in one subtask
      put <obj> on/in <loc>   — transport + release in one subtask
      place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
      push <obj>              — contact + linear shove
      pull <obj>              — contact + linear retract
      turn <knob/dial/handle> — rotary actuation
      press <button>          — single-press contact
      open <drawer/door/lid>  — full open motion
      close <drawer/door/lid> — full close motion
      pour <src> into <dst>   — tilt + flow
      insert <obj> into <slot>— alignment + push-fit
      go to <loc>             — ONLY when no grasp / actuation follows
                                (e.g. a pure relocation between phases).
                                If the next subtask grasps something at
                                that location, drop "go to ..." and just
                                write "pick up ..." instead.
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
  as standalone subtasks; fold them into the parent composite:
      "move to X"    → fold into "pick up X" (or whatever follows)
      "reach for X"  → fold into "pick up X"
      "grasp X"      → fold into "pick up X"
      "lift X"       → fold into "pick up X" (or "put X on Y" if it's
                       the transport phase of a place)
      "release X"    → fold into "put X on Y" (or "place X in Y")
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
  ("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
  detail (which hand, which grasp point) ONLY when it is needed to
  disambiguate. Every subtask must begin with one of the verbs
  above (no leading nouns, no "then", no "first").
- NEVER use third person. Never write "the robot", "the arm", "the
  gripper moves", "it picks up" — the robot is implied. Command it,
  do not describe it.
- Use the exact object nouns from the task above. If the task says
  "cube", every subtask says "cube" — never switch to "block". If it
  says "box", never switch to "bin"/"container". Keep vocabulary
  consistent across the whole episode.
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
  "turn red knob", "press start button", "go to sink".
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
  must be folded into "pick up blue cube"); "the robot arm moves
  towards the blue cube" (third person, too long); "carefully pick
  up the cube" (adverb, article); "release the yellow block"
  ("block" when the task said "cube", and "release" must be folded
  into a "put"/"place" subtask).
- Subtasks are non-overlapping and cover the full episode in order.
  Choose the cut points yourself based on what you see in the video
  (gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds. If a
  candidate span would be shorter, merge it into its neighbour
  rather than emitting it.
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
  are preferred over many micro-steps.
- Every subtask's [start_time, end_time] must lie within
  [0.0, {episode_duration}] seconds.

SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
fires ONLY on the spatial situation it names; it must not change how you
label any other situation):
- STACK vs PUT: if an object is placed ON TOP OF another specific object
  (not on a flat table / shelf / counter), use "stack ... on ...", not
  "put". "stack blue book on green book", NOT "put blue book on table".
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
  receptacle (push-fit), use "insert ... into ...", not "put".
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
  on the object and the object moves WITH the hand, it is "pick up" /
  "retrieve" (object leaves its location). If the gripper OPENS and the
  object stays where the hand left it, it is "put" / "place" (object
  arrives at a location). Decide by which way the object moves, not by
  where the hand ends up.
- POUR vs PUT: only use "pour" when the source is tilted and contents
  flow out; moving a full container without tilting is "put"/"place".

Output strictly valid JSON of shape:

{{
  "subtasks": [
    {{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
    ...
  ]
}}
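The numeric constraints in the prompt above (a ceiling of `{max_steps}` subtasks, a minimum span of `{min_subtask_seconds}`, no overlap, and times inside `[0.0, {episode_duration}]`) are the kind of thing a caller could re-check on the model's JSON reply. The sketch below is one possible way to do that; `check_subtasks` is a hypothetical helper written for this example and is not taken from the pipeline.

```python
from __future__ import annotations

import json


def check_subtasks(
    raw: str, *, episode_duration: float, min_subtask_seconds: float, max_steps: int
) -> list[str]:
    """Return human-readable problems found in a subtask reply (empty if none)."""
    subtasks = json.loads(raw).get("subtasks", [])
    problems: list[str] = []
    if len(subtasks) > max_steps:
        problems.append(f"{len(subtasks)} subtasks exceeds max_steps={max_steps}")
    prev_end = 0.0
    for i, st in enumerate(subtasks):
        start, end = float(st["start"]), float(st["end"])
        if start < 0.0 or end > episode_duration:
            problems.append(f"subtask {i} lies outside [0.0, {episode_duration}]")
        if end - start < min_subtask_seconds:
            problems.append(f"subtask {i} is shorter than {min_subtask_seconds}s")
        if start < prev_end:
            problems.append(f"subtask {i} overlaps the previous one")
        prev_end = end
    return problems


# Hypothetical replies for illustration only.
good = json.dumps(
    {
        "subtasks": [
            {"text": "pick up cube", "start": 0.0, "end": 2.0},
            {"text": "put cube in box", "start": 2.0, "end": 5.0},
        ]
    }
)
bad = json.dumps(
    {
        "subtasks": [
            {"text": "pick up cube", "start": 0.0, "end": 3.0},
            {"text": "put cube in box", "start": 2.0, "end": 2.5},
        ]
    }
)
good_problems = check_subtasks(good, episode_duration=5.0, min_subtask_seconds=1.0, max_steps=4)
bad_problems = check_subtasks(bad, episode_duration=5.0, min_subtask_seconds=1.0, max_steps=4)
```

In the `bad` reply the second subtask is both too short and starts before the first one ends, so two problems are reported for it.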
@@ -1,67 +0,0 @@
You are generating structured augmentations of a robot task instruction
for training a language-conditioned policy. Unlike free-form rephrasing,
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
a specific element of the task while preserving its meaning.

Original task: "{base_task}"

Produce variants along five named axes. Each axis has a target count.
The whole batch should expose the policy to maximum linguistic diversity
WITHOUT changing what the robot is supposed to do.

Axes and target counts:

  synonym_paraphrase ({n_synonym}):
    Different wording / verbs / sentence structure. ALL information
    from the original task is preserved — same object, same arm
    specification if present, same orientation if present, same grasp
    if present.

  omit_arm ({n_omit_arm}):
    Drop the left/right/both arm specification from the task. Skip
    entirely (emit 0 entries) if the original task does NOT mention an
    arm. Do not invent an arm specification just to omit it.

  omit_orientation ({n_omit_orientation}):
    Drop orientation cues (upright, sideways, facing the user,
    long-edge-first, etc.). Skip entirely if no orientation cue is
    present in the original task.

  omit_grasp_method ({n_omit_grasp_method}):
    Drop the grip / grasp method specification (pinch, wrap, hold by
    the rim, etc.). Skip entirely if no grasp method is mentioned.

  combined_omissions ({n_combined}):
    Combine TWO of the above omissions simultaneously (e.g. drop both
    arm and orientation). Skip entirely if fewer than two of (arm,
    orientation, grasp_method) appear in the original task.

Hard rules:
- Each variant MUST preserve the core action, the target object, AND
  the goal / destination. Do not change which object is involved, where
  it goes, or the high-level action. "Navigate to the stove" may become
  "go to the stove" or "head over to the stove" — it must NEVER become
  "wander around the kitchen", "explore the room", or anything that
  drops or generalises the stove destination. If you cannot vary the
  wording without changing the goal, emit fewer variants.
- Only the FIVE listed elements (wording, arm, orientation, grasp
  method, or a combination) may be varied or omitted. The verb's
  meaning, the object, and the destination are fixed.
- Each variant is plain prose, no markdown, no quotes, no list numbers.
- Each variant must be DISTINCT from every other variant in the entire
  output, both within and across axes. Near-duplicates are not allowed.
- If an axis cannot reach its target count because the original task
  lacks the omittable element, emit fewer entries — do NOT pad the
  axis with paraphrases that belong to a different axis.
- Variants should not all start with verbs — vary sentence structure
  (some imperative, some polite request, some question).

Output strictly valid JSON of shape:

{{
  "synonym_paraphrase": ["<v1>", "<v2>", ...],
  "omit_arm": ["<v1>", "<v2>", ...],
  "omit_orientation": ["<v1>", ...],
  "omit_grasp_method": ["<v1>", ...],
  "combined_omissions": ["<v1>", ...]
}}
@@ -1,32 +0,0 @@
You are generating training data for a Hi Robot-style policy. We need
{n} alternative phrasings of the same robot task so the policy sees
diverse user prompts during training instead of the same canonical
string repeated every frame.

Original task:
"{base_task}"

Generate exactly {n} alternative phrasings of the same task. Vary:

- formality (casual / polite / curt)
- verbosity (mostly short imperative; occasional polite request)
- word choice (synonyms, different verbs)
- sentence structure (imperative / question / suggestion)

Hard rules:
- Each phrasing MUST preserve the exact meaning of the original task.
  Do not change which object is involved, the destination, or the
  action. Do not add extra steps. Do not invent new objects.
- Each phrasing must be a short phrase or sentence, plain prose, no
  markdown, no quotes, no list numbers.
- Phrasings must be distinct — no near-duplicates.
- Output exactly {n} entries.

Output strictly valid JSON:
{{
  "rephrasings": [
    "<phrasing 1>",
    "<phrasing 2>",
    ...
  ]
}}
@@ -1,17 +0,0 @@
The video above shows a robot manipulation episode in full. Look at
the entire video and describe in ONE concise sentence what the robot
is doing.

Rules:
- One sentence, in natural English, like a user instruction.
- Capture the goal of the demonstration, not low-level motions.
  Example: "place the yellow cube into the red bin" — not "move the
  end-effector down 5cm and close the gripper".
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
- Do not invent objects or actions that aren't visible.
- Do not output anything other than the JSON object below.

Output strictly valid JSON:
{{
  "task": "<single concise sentence describing what the robot does in this video>"
}}
@@ -1,32 +0,0 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.

The frame shows a robot working on: "{episode_task}".

Question types and the EXACT answer JSON shape required for each:

  bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
      "bbox": [x1, y1, x2, y2]}}, ...]}}
      bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
      ECoT example: "a white cup [124, 25, 176, 113]".

  keypoint => {{"label": "<point>", "point_format": "xy",
      "point": [x, y]}}

  count => {{"label": "<obj>", "count": <int>,
      "note": "<optional short note>"}}

  attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
      "value": "<observed value>"}}

  spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
      "above|below|near>", "object": "<obj>"}}

Generate a question of type "{question_type}". Output strictly valid JSON:

{{
  "question": "<short, frame-grounded question>",
  "answer": <object whose shape matches the schema above>
}}
@@ -1,216 +0,0 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.

The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:

- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)

This shape lets each module operate per-episode without loading all parquet
rows into memory at once.
"""

from __future__ import annotations

from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import pyarrow.parquet as pq

from lerobot.datasets.io_utils import load_tasks
from lerobot.datasets.utils import DEFAULT_TASKS_PATH


@dataclass
class EpisodeRecord:
    """Per-episode record yielded by the reader."""

    episode_index: int
    episode_task: str
    frame_timestamps: tuple[float, ...]
    frame_indices: tuple[int, ...]
    data_path: Path
    row_offset: int  # row offset within the parquet file where this episode starts
    row_count: int  # number of rows for this episode

    # Memoized parquet slice — populated on first ``frames_df()`` call so
    # repeat queries from different modules don't re-read the whole shard.
    _frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)

    def frames_df(self):  # type: ignore[no-untyped-def]
        """Lazy-load the pandas slice for this episode (memoized)."""
        if self._frames_df_cache is None:
            import pandas as pd  # noqa: PLC0415 - deferred for optional dataset extra

            table = pq.read_table(self.data_path)
            df: pd.DataFrame = table.to_pandas()
            self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
                drop=True
            )
        return self._frames_df_cache


def reconstruct_subtask_spans(
    rows: Sequence[dict[str, Any]],
    *,
    episode_end_t: float | None = None,
) -> list[dict[str, Any]]:
    """Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.

    Each span's ``end`` is the next span's ``start``. The final span's
    ``end`` defaults to its own ``start`` (zero-duration) — pass
    ``episode_end_t`` to extend it to the episode's last frame instead,
    which is what downstream consumers (memory, interjection boundary
    selection) expect.

    Used by the ``plan`` module (plan-update pass) and the
    ``interjections`` module (interjection anchoring), which both need the
    same span shape.
    """
    sorted_rows = sorted(
        (r for r in rows if r.get("style") == "subtask"),
        key=lambda r: float(r["timestamp"]),
    )
    spans: list[dict[str, Any]] = []
    for r in sorted_rows:
        t = float(r["timestamp"])
        if spans:
            spans[-1]["end"] = t
        spans.append({"text": r.get("content") or "", "start": t, "end": t})
    if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
        spans[-1]["end"] = float(episode_end_t)
    return spans


def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
    """Snap an arbitrary float to the nearest exact source frame timestamp.

    Modules use this when emitting event-style rows so the row's
    timestamp matches a real parquet frame: event rows must land on an
    exact frame, otherwise the per-frame event lookup the writer does
    would never match them.
    """
    if not frame_timestamps:
        return float(t)
    nearest = min(frame_timestamps, key=lambda f: abs(f - t))
    return float(nearest)


def _load_tasks_lookup(root: Path) -> dict[int, str]:
    """Map ``task_index -> task`` from ``meta/tasks.parquet``.

    Returns an empty dict when the file is absent — the task description is
    derived later from the video if needed. Reuses the library-level
    :func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
    frame indexed by task string with a ``task_index`` column.
    """
    if not (root / DEFAULT_TASKS_PATH).exists():
        return {}
    tasks = load_tasks(root)
    return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}


def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
    """Yield :class:`EpisodeRecord` for every episode under ``root/data/``.

    Episodes are yielded in ascending ``episode_index`` order. The reader does
    not assume a specific chunk/file layout: it scans every ``*.parquet``
    under ``data/`` and groups by ``episode_index``.
    """
    tasks = _load_tasks_lookup(root)
    data_dir = root / "data"
    parquet_files = sorted(data_dir.rglob("*.parquet"))

    only_set = set(only_episodes) if only_episodes is not None else None

    for path in parquet_files:
        yield from _iter_one_path(path, tasks, only_set)


def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
    table = pq.read_table(path)
    names = table.column_names
    if "episode_index" not in names:
        return
    episode_col = table.column("episode_index").to_pylist()
    timestamp_col = (
        table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
    )
    frame_col = (
        table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
    )
    task_col = table.column("task_index").to_pylist() if "task_index" in names else None

    def _build(
        ep: int,
        start: int,
        end: int,
        task_idx: int | None,
        ts_buf: list[float],
        fi_buf: list[int],
    ) -> EpisodeRecord | None:
        if only_set is not None and ep not in only_set:
            return None
        task = tasks.get(task_idx, "") if task_idx is not None else ""
        return EpisodeRecord(
            episode_index=ep,
            episode_task=task,
            frame_timestamps=tuple(ts_buf),
            frame_indices=tuple(fi_buf),
            data_path=path,
            row_offset=start,
            row_count=end - start,
        )

    cur_ep: int | None = None
    start_offset = 0
    ts_buf: list[float] = []
    fi_buf: list[int] = []
    cur_task_idx: int | None = None

    for i, ep in enumerate(episode_col):
        if cur_ep is None:
            cur_ep = ep
            start_offset = i
            ts_buf = [timestamp_col[i]]
            fi_buf = [frame_col[i]]
            cur_task_idx = task_col[i] if task_col is not None else None
            continue
        if ep != cur_ep:
            rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
            if rec is not None:
                yield rec
            cur_ep = ep
            start_offset = i
            ts_buf = [timestamp_col[i]]
            fi_buf = [frame_col[i]]
            cur_task_idx = task_col[i] if task_col is not None else None
        else:
            ts_buf.append(timestamp_col[i])
            fi_buf.append(frame_col[i])

    if cur_ep is not None:
        rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
        if rec is not None:
            yield rec
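The explicit loop in `_iter_one_path` splits a shard's `episode_index` column into contiguous runs and records each run's row offset and row count. The same grouping can be sketched more compactly with `itertools.groupby`; this is an illustrative alternative under the assumption that rows of one episode are stored contiguously, and `episode_runs` is a hypothetical name, not a function from the reader.

```python
from __future__ import annotations

from collections.abc import Sequence
from itertools import groupby


def episode_runs(episode_col: Sequence[int]) -> list[tuple[int, int, int]]:
    """Return ``(episode_index, row_offset, row_count)`` for each contiguous run."""
    runs: list[tuple[int, int, int]] = []
    offset = 0
    for ep, group in groupby(episode_col):
        count = sum(1 for _ in group)  # number of consecutive rows with this index
        runs.append((ep, offset, count))
        offset += count
    return runs


# Hypothetical episode_index column for one shard.
runs = episode_runs([0, 0, 0, 1, 1, 2])
```

Like the original loop, this treats a repeated but non-adjacent episode index as a new run, so it relies on the shard layout keeping each episode's rows together.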
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class EpisodeStaging:
    """Filesystem layout for a single episode's staged module outputs."""

    root: Path
    episode_index: int

    @property
    def episode_dir(self) -> Path:
        return self.root / f"episode_{self.episode_index:06d}"

    def path_for(self, module: ModuleName) -> Path:
        if module not in _MODULES:
            raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
        return self.episode_dir / f"{module}.jsonl"

    def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
        path = self.path_for(module)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Atomic replace: a crash mid-write would otherwise leave a
        # half-written JSONL file that ``read()`` would then fail to
        # parse. Write to a sibling .tmp and rename so the target path
        # only ever points at a complete file.
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
                f.write("\n")
        tmp_path.replace(path)
        return path

    def read(self, module: ModuleName) -> list[dict[str, Any]]:
        path = self.path_for(module)
        if not path.exists():
            return []
        out: list[dict[str, Any]] = []
        with path.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    out.append(json.loads(line))
        return out

    def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
        return {m: self.read(m) for m in _MODULES}

    def has(self, module: ModuleName) -> bool:
        return self.path_for(module).exists()
|
||||
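The write-to-temp-then-rename pattern used by ``EpisodeStaging.write`` can be tried in isolation. The following is a minimal sketch, not the project's code: ``write_jsonl_atomic`` and ``read_jsonl`` are hypothetical helper names, and the episode directory and row contents are made up for illustration.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl_atomic(path: Path, rows: list[dict]) -> Path:
    """Write rows to a sibling .tmp file, then rename it over ``path``."""
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
            f.write("\n")
    # Path.replace renames the temp file onto the target, so readers
    # see either the old complete file or the new complete file.
    tmp_path.replace(path)
    return path


def read_jsonl(path: Path) -> list[dict]:
    """Read one JSON object per non-empty line; a missing file yields []."""
    if not path.exists():
        return []
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "episode_000003" / "plan.jsonl"
    rows = [{"style": "plan", "timestamp": 0.0, "content": "pick up the cube"}]
    write_jsonl_atomic(target, rows)
    roundtrip = read_jsonl(target)
    leftover_tmp = target.with_suffix(".jsonl.tmp").exists()
    missing = read_jsonl(Path(tmp) / "episode_000003" / "vqa.jsonl")
```

After the rename no ``.tmp`` file remains, and a module that never wrote output reads back as an empty list rather than an error.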
@@ -1,332 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.

Checks (per the plan's "Intermediate staging and validation" section):

- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
  attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||
try:
|
||||
column = column_for_style(row.get("style"))
|
||||
except ValueError:
|
||||
continue
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||
return
|
||||
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
    def _check_plan_memory_consistency(
        self,
        persistent: Sequence[dict[str, Any]],
        events: Sequence[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        def _timestamps(rows: Sequence[dict[str, Any]], style: str) -> set[float]:
            # Rows without a timestamp are skipped here rather than raising
            # KeyError / TypeError and aborting the whole validation pass.
            return {
                float(r["timestamp"])
                for r in rows
                if r.get("style") == style and r.get("timestamp") is not None
            }

        plan_ts = _timestamps(persistent, "plan")
        memory_ts = _timestamps(persistent, "memory")
        subtask_ts = _timestamps(persistent, "subtask")
        interjection_ts = _timestamps(events, "interjection")

        if persistent and not plan_ts:
            report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
        # every interjection should have a same-timestamp plan refresh
        for ts in sorted(interjection_ts):
            if ts not in plan_ts:
                report.add_error(
                    f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
                )
        # memory should be emitted at subtask boundaries (subset relation)
        if memory_ts and subtask_ts:
            stray = sorted(memory_ts - subtask_ts)
            if stray:
                report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
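The difference between the exact-match default (``timestamp_atol = 0.0``) and a tolerance-based check is easiest to see on floating-point frame times. The sketch below is a standalone illustration: ``is_aligned`` is a hypothetical helper that mirrors the two branches of ``_check_event_timestamp_alignment``, and the 30 fps frame grid is an assumption.

```python
def is_aligned(ts: float, frame_ts: set[float], atol: float = 0.0) -> bool:
    """Exact set membership when atol is 0, otherwise nearest-frame tolerance."""
    if atol == 0.0:
        return ts in frame_ts
    return any(abs(ts - f) <= atol for f in frame_ts)


fps = 30  # assumed frame rate for the example
frame_ts = {i / fps for i in range(90)}

# A timestamp copied verbatim from the source frames matches exactly.
copied = is_aligned(10 / fps, frame_ts)

# A timestamp that was rounded for display no longer matches exactly ...
rounded = round(10 / fps, 3)
rounded_exact = is_aligned(rounded, frame_ts)

# ... but passes once a small tolerance is allowed.
rounded_with_atol = is_aligned(rounded, frame_ts, atol=1e-3)
```

This is why modules should emit the source frame timestamp itself rather than a recomputed or rounded value when the validator runs in exact-match mode.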
|
||||
@@ -1,617 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. The client talks to an
OpenAI-compatible server (for example ``vllm serve``); see
:func:`make_vlm_client` for the backends that are currently accepted. A
``stub`` backend is used for unit tests so fixtures never call into a real
model.

The client speaks one method, :meth:`VlmClient.generate_json`, which:

- accepts a batch of OpenAI/HF-style multimodal message lists,
- parses a JSON value out of each reply,
- fans requests out concurrently when ``client_concurrency`` allows it,
- and reprompts once on a JSON parse failure with an inline correction
  message, returning ``None`` for that item if the retry also fails.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the full message list of one request to
a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
    text = text.strip()
    # Strip <think>...</think> blocks (Qwen3 Thinking style)
    while True:
        start = text.find("<think>")
        close = text.find("</think>", start) if start != -1 else -1
        if close == -1:
            # No complete block left. Stop here so a stray or reversed
            # tag (``</think>`` before ``<think>``) cannot loop forever.
            break
        text = (text[:start] + text[close + len("</think>") :]).strip()
    # Strip ```json ... ``` fences from chat-tuned backbones
    if text.startswith("```"):
        first = text.find("\n")
        last = text.rfind("```")
        if first != -1 and last != -1 and last > first:
            text = text[first + 1 : last].strip()
    try:
        return json.loads(text)
    except (ValueError, json.JSONDecodeError):
        pass
    # Fall back to extracting the first balanced {...} block.
    obj_text = _extract_first_json_object(text)
    if obj_text is None:
        raise json.JSONDecodeError("No JSON object found", text, 0)
    return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
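A quick way to see what the balanced-brace fallback buys is to run it on a chatty model reply. The sketch below re-implements the scan under the hypothetical name ``first_json_object`` so it runs on its own; the reply strings are invented examples.

```python
from __future__ import annotations

import json


def first_json_object(text: str) -> str | None:
    """Return the first balanced {...} substring, ignoring braces in strings."""
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    in_string = False
    escape = False
    for i in range(start, len(text)):
        ch = text[i]
        if escape:  # previous char was a backslash: skip this one
            escape = False
            continue
        if ch == "\\":
            escape = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return None


reply = 'Sure, here it is: {"label": "cup {red}", "count": 2} Hope that helps!'
parsed = json.loads(first_json_object(reply))
truncated = first_json_object('{"label": "cup", "count": ')
no_object = first_json_object("no braces here")
```

Braces inside the string value ``"cup {red}"`` do not affect the depth counter, and a truncated reply yields ``None`` instead of a partial object.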
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
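The one-retry behavior can be exercised without any model by passing in a fake text generator. This is a simplified sketch for a single request rather than a batch: ``generate_json_with_retry``, ``flaky_model`` and ``stubborn_model`` are hypothetical names, and plain ``json.loads`` stands in for ``_strip_to_json``.

```python
from __future__ import annotations

import json
from collections.abc import Callable
from typing import Any

Messages = list[dict[str, Any]]

CORRECTION = (
    "Your previous reply was not valid JSON. "
    "Reply with strictly valid JSON, no prose, no fences."
)


def generate_json_with_retry(
    generate_text: Callable[[Messages], str], messages: Messages
) -> Any | None:
    """Parse the reply as JSON; on failure reprompt once, then give up with None."""
    text = generate_text(messages)
    try:
        return json.loads(text)
    except ValueError:
        pass
    # Show the model its own bad reply plus a correction request.
    retry = list(messages) + [
        {"role": "assistant", "content": text},
        {"role": "user", "content": CORRECTION},
    ]
    try:
        return json.loads(generate_text(retry))
    except ValueError:
        return None


def flaky_model(messages: Messages) -> str:
    # Fake model: only answers in JSON after it has seen the correction.
    if messages[-1]["content"] == CORRECTION:
        return '{"count": 2}'
    return "There are two cups."


def stubborn_model(messages: Messages) -> str:
    return "still not JSON"


question = [{"role": "user", "content": "How many cups are visible?"}]
recovered = generate_json_with_retry(flaky_model, question)
gave_up = generate_json_with_retry(stubborn_model, question)
```

Returning ``None`` keeps one bad reply from aborting a long annotation run; callers are expected to treat it as "skip this item".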
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client.
|
||||
|
||||
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||
optionally auto-spawning the server via ``auto_serve`` /
|
||||
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||
backends were removed to keep the support surface to the HF Jobs path.
|
||||
|
||||
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||
callable; it is rejected here to make accidental misuse obvious.
|
||||
"""
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend in {"vllm", "transformers"}:
|
||||
raise ValueError(
|
||||
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||
)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
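The combination of a lock-protected round-robin counter and an order-preserving thread pool, as used by ``_one_call`` and ``_gen``, can be sketched with no network calls. The endpoint URLs and ``fake_call`` below are placeholders for illustration only.

```python
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

endpoints = ["http://localhost:8000/v1", "http://localhost:8001/v1"]
counter = {"i": 0}
lock = threading.Lock()


def pick_endpoint() -> str:
    # The lock makes read-then-increment atomic across worker threads.
    with lock:
        chosen = endpoints[counter["i"] % len(endpoints)]
        counter["i"] += 1
    return chosen


def fake_call(prompt: str) -> tuple[str, str]:
    # Stand-in for a chat completion request: report which server was used.
    return prompt, pick_endpoint()


prompts = [f"prompt-{i}" for i in range(6)]
with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(fake_call, p) for p in prompts]
    # Collecting in submission order keeps results aligned with prompts,
    # regardless of which thread finishes first.
    results = [f.result() for f in futures]

returned_prompts = [p for p, _ in results]
load = Counter(endpoint for _, endpoint in results)
```

Which prompt lands on which server depends on thread scheduling, but the total load is split evenly and the result list keeps the input order.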
|
||||
|
||||
|
||||
def _bind_serve_port(cmd: str, port: int) -> str:
|
||||
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
|
||||
if present, else append ``--port`` when the command omits it (leaving an
|
||||
explicit ``--port`` untouched). Shared by the single- and parallel-server
|
||||
paths so a serve_command never reaches the server with a literal
|
||||
``{port}``."""
|
||||
if "{port}" in cmd:
|
||||
return cmd.replace("{port}", str(port))
|
||||
if "--port" not in cmd:
|
||||
return f"{cmd} --port {port}"
|
||||
return cmd
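The three cases handled by ``_bind_serve_port`` are shown below on sample commands. The function body is copied from the lines above under the name ``bind_serve_port`` so the snippet runs standalone; the model name ``my-model`` is a placeholder.

```python
def bind_serve_port(cmd: str, port: int) -> str:
    if "{port}" in cmd:
        return cmd.replace("{port}", str(port))
    if "--port" not in cmd:
        return f"{cmd} --port {port}"
    return cmd


# 1. A ``{port}`` placeholder is substituted.
templated = bind_serve_port("vllm serve my-model --port {port}", 8001)

# 2. A command with no port flag gets one appended.
appended = bind_serve_port("vllm serve my-model", 8001)

# 3. An explicit ``--port`` is left untouched, even if it differs.
explicit = bind_serve_port("vllm serve my-model --port 9000", 8001)
```

Note the third case: with an explicit port in ``serve_command``, the client still targets ``serve_port``, so the two values need to agree.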
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = _bind_serve_port(base_cmd, port)
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
# Bind the single server to ``serve_port`` (what ``api_base`` below
|
||||
# targets): substitute a literal ``{port}`` placeholder, else append
|
||||
# ``--port``. Without this a serve_command carrying ``{port}`` would
|
||||
# reach the server unsubstituted and fail to parse.
|
||||
cmd = _bind_serve_port(cmd, config.serve_port)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
    # Watch the server output for a readiness banner. Log watching is the
    # primary signal because ``transformers serve`` rescans its cache on
    # every model-list request, so a ``/v1/models`` poll can exceed the
    # urllib timeout and keep failing. The HTTP probe below runs alongside
    # it as a second signal.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
    def _stream_output() -> None:
        # Read one character at a time instead of iterating lines so tqdm
        # progress bars (which overwrite using \r) flush in real time.
        assert proc.stdout is not None
        buf = ""
        prefix_started = False
        while True:
            ch = proc.stdout.read(1)
            if ch == "":
                # EOF: the process exited. Every character was already
                # echoed as it arrived, so re-writing ``buf`` here would
                # print the unterminated tail twice. Just flush.
                sys.stdout.flush()
                return
            if not prefix_started:
                sys.stdout.write("[server] ")
                prefix_started = True
            sys.stdout.write(ch)
            sys.stdout.flush()
            buf += ch
            if ch in ("\n", "\r"):
                if any(marker in buf for marker in ready_markers):
                    ready_event.set()
                buf = ""
                prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
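The readiness wait above pairs a `time.monotonic()` deadline with `threading.Event.wait`. A minimal stdlib-only sketch of that pattern follows; `wait_until_ready`, `has_exited`, and `poll_s` are hypothetical names for illustration and are not part of lerobot.

```python
import threading
import time
from collections.abc import Callable


def wait_until_ready(
    ready: threading.Event,
    has_exited: Callable[[], bool],
    timeout_s: float,
    poll_s: float = 0.05,
) -> bool:
    """Return True once ``ready`` is set, False if ``timeout_s`` elapses.

    Raises RuntimeError if ``has_exited()`` reports that the worker died
    before signalling readiness. Hypothetical helper, for illustration only.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if has_exited():
            raise RuntimeError("worker exited before becoming ready")
        # Event.wait returns True as soon as the event is set, so a ready
        # worker is noticed without sleeping for the full poll interval.
        if ready.wait(timeout=poll_s):
            return True
    return False


# Usage: a timer stands in for the log-watching thread setting the event.
ready_event = threading.Event()
threading.Timer(0.1, ready_event.set).start()
became_ready = wait_until_ready(ready_event, has_exited=lambda: False, timeout_s=2.0)
```

Using a monotonic clock keeps the deadline immune to wall-clock adjustments, and checking for process exit on every iteration surfaces a crashed server promptly instead of waiting out the full timeout.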
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
    """Read a local video file and return a base64 ``data:video/mp4`` URL."""
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("ascii")
    return f"data:video/mp4;base64,{b64}"


def _pil_to_data_url(image: Any) -> str:
    """Encode a PIL.Image as a base64 data URL."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{b64}"
|
||||
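The `video_url` handling in `_to_openai_messages` can be sketched in isolation: inline a `file://` URL as a base64 data URL and lift `fps` out of the block into separate processor kwargs. This is a stdlib-only sketch under stated assumptions; `file_to_data_url` and `convert_video_block` are hypothetical names, and the three-byte file merely stands in for a real clip.

```python
import base64
import tempfile
from pathlib import Path
from typing import Any


def file_to_data_url(path: str, mime: str = "video/mp4") -> str:
    """Return the file's bytes as a base64 ``data:`` URL (hypothetical helper)."""
    b64 = base64.b64encode(Path(path).read_bytes()).decode("ascii")
    return f"data:{mime};base64,{b64}"


def convert_video_block(block: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split one ``video_url`` block into an API block and processor kwargs."""
    video_url = dict(block["video_url"])  # copy so the caller's block is untouched
    url = video_url.get("url", "")
    if url.startswith("file://"):
        video_url["url"] = file_to_data_url(url[len("file://"):])
    mm_kwargs: dict[str, Any] = {}
    if block.get("fps") is not None:
        # ``fps`` travels separately instead of inside the content block.
        mm_kwargs["fps"] = block["fps"]
    return {"type": "video_url", "video_url": video_url}, mm_kwargs


# Usage with a throwaway file standing in for a real clip.
with tempfile.TemporaryDirectory() as tmp:
    clip = Path(tmp) / "clip.mp4"
    clip.write_bytes(b"\x00\x01\x02")
    api_block, kwargs = convert_video_block(
        {"type": "video_url", "video_url": {"url": f"file://{clip}"}, "fps": 2}
    )
```

Base64 inlining grows the payload by roughly a third, so this approach suits short clips sent to a local server rather than large videos.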
@@ -1,341 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.

For every episode the writer:

1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
   slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
   exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. writes the parquet shard back in place.

The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant,
``SAY_TOOL_SCHEMA``, which is re-exported below from
``lerobot.datasets.language``. Downstream chat-template consumers import
it directly rather than reading a redundant per-row column.

Invariants enforced here (and re-checked by the validator):

- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
  (timestamps come straight from the source parquet and are never recomputed);
- every row passes ``column_for_style(style)``.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
    return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")


def _row_event_sort_key(row: dict[str, Any]) -> tuple:
    # events are bucketed per-frame, but within a frame we still want determinism
    return (
        row.get("style") or "",
        row.get("role") or "",
        row.get("camera") or "",
    )
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
    """Return a shallow copy of ``value`` if it is a list, or None if it is None."""
    if value is None:
        return None
    if not isinstance(value, list):
        raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
    return list(value)


def _validate_atom_invariants(row: dict[str, Any]) -> None:
    """Require at least one of content/tool_calls; style=None requires tool_calls."""
    has_content = row.get("content") is not None
    has_tools = row.get("tool_calls") is not None
    if not (has_content or has_tools):
        raise ValueError(f"row has neither content nor tool_calls: {row!r}")
    if row.get("style") is None and not has_tools:
        raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
    """Speech atoms: role=assistant, style=None, content=None, say tool call."""
    if row.get("style") is not None:
        return  # not a speech atom
    if row.get("role") != "assistant":
        raise ValueError(f"speech atom must have role=assistant: {row!r}")
    if row.get("content") is not None:
        raise ValueError(f"speech atom must have content=null: {row!r}")
    tool_calls = row.get("tool_calls")
    if not tool_calls or not isinstance(tool_calls, list):
        raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
    first = tool_calls[0]
    if not isinstance(first, dict):
        raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
    if first.get("type") != "function":
        raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
    fn = first.get("function") or {}
    if fn.get("name") != "say":
        raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
    args = fn.get("arguments") or {}
    if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
        raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
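The atomic replace at the end of `_rewrite_one` (write to a sibling `.tmp` path, then rename) can be sketched with the standard library alone. `atomic_write_bytes` is a hypothetical helper, and the `os.fsync` call is an extra durability step assumed here that the writer above does not perform.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path`` via a sibling temp file plus rename."""
    # Same directory as the target, so the rename stays on one filesystem.
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    with open(tmp_path, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())  # assumption: extra durability, not in the writer
    # Readers see either the old file or the new one, never a partial write.
    tmp_path.replace(path)


# Usage: overwrite a stand-in shard and confirm no temp file is left behind.
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "file-000.parquet"
    shard.write_bytes(b"old")
    atomic_write_bytes(shard, b"new")
    final_bytes = shard.read_bytes()
    leftovers = sorted(p.name for p in Path(tmp).iterdir())
```

If the process dies before the rename, the original shard is untouched and only a stale `.tmp` file remains, which a rerun overwrites.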
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
    """Build a canonical speech tool-call atom for the events column."""
    return {
        "role": "assistant",
        "content": None,
        "style": None,
        "timestamp": float(timestamp),
        "camera": None,
        "tool_calls": [
            {
                "type": "function",
                "function": {
                    "name": "say",
                    "arguments": {"text": text},
                },
            }
        ],
    }
|
||||
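The per-frame materialization the writer performs (broadcast the persistent slice to every frame, attach events only where timestamps match exactly) can be sketched in plain Python without pyarrow. `materialize_frames` is a hypothetical name, and the `"task"` style value is a placeholder, since the real members of `PERSISTENT_STYLES` are not shown here.

```python
from collections import defaultdict
from typing import Any


def materialize_frames(
    frame_timestamps: list[float],
    persistent_rows: list[dict[str, Any]],
    event_rows: list[dict[str, Any]],
) -> list[tuple[list[dict[str, Any]], list[dict[str, Any]]]]:
    """Return one ``(persistent, events)`` pair per frame of a single episode."""
    buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
    for row in event_rows:
        # Events attach only to the frame whose timestamp matches exactly.
        buckets[float(row["timestamp"])].append(row)
    # Every frame shares the same persistent list; unmatched frames get [].
    return [(persistent_rows, buckets.get(float(ts), [])) for ts in frame_timestamps]


frames = materialize_frames(
    frame_timestamps=[0.0, 0.1, 0.2],
    # "task" is a placeholder style, not a confirmed PERSISTENT_STYLES member.
    persistent_rows=[{"role": "user", "content": "pick up the cube", "style": "task", "timestamp": 0.0}],
    event_rows=[{"role": "assistant", "content": None, "style": None, "timestamp": 0.1}],
)
```

Exact float equality works only because event timestamps are copied from the source frames rather than recomputed; an event whose timestamp matches no frame would be dropped silently in this sketch.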
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(
|
||||
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||
) -> None:
|
||||
state: dict = {"step": step}
|
||||
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||
# compute_sampler_state).
|
||||
if num_processes is not None:
|
||||
state["num_processes"] = num_processes
|
||||
if batch_size is not None:
|
||||
state["batch_size"] = batch_size
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -96,8 +75,6 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -123,10 +100,6 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -139,9 +112,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -149,8 +120,6 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -162,12 +131,10 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -73,8 +73,6 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot v3.0 dataset on disk.
|
||||
recording: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size == 0:
|
||||
|
||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -82,7 +82,6 @@ __all__ = [
|
||||
"aggregate_stats",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"compute_sampler_state",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
|
||||
@@ -286,8 +286,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
@@ -305,8 +303,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
@@ -355,12 +351,8 @@ def aggregate_datasets(
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
|
||||
)
|
||||
data_idx = aggregate_data(
|
||||
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
|
||||
)
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
@@ -375,9 +367,7 @@ def aggregate_datasets(
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
def aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
|
||||
):
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||
|
||||
Handles video file concatenation and rotation based on file size limits.
|
||||
@@ -389,7 +379,6 @@ def aggregate_videos(
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -450,7 +439,7 @@ def aggregate_videos(
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
@@ -488,7 +477,7 @@ def aggregate_videos(
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate or dst_size + src_size >= max_mb:
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
@@ -59,8 +59,6 @@ class RunningQuantileStats:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
# Promote integer and low-precision inputs before computing squared statistics.
|
||||
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
|
||||
@@ -261,8 +261,6 @@ def merge_datasets(
|
||||
datasets: list[LeRobotDataset],
|
||||
output_repo_id: str,
|
||||
output_dir: str | Path | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
) -> LeRobotDataset:
|
||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||
|
||||
@@ -272,8 +270,6 @@ def merge_datasets(
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -288,8 +284,6 @@ def merge_datasets(
|
||||
aggr_repo_id=output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=output_dir,
|
||||
concatenate_videos=concatenate_videos,
|
||||
concatenate_data=concatenate_data,
|
||||
)
|
||||
|
||||
merged_dataset = LeRobotDataset(
|
||||
|
||||
+32
-122
@@ -14,36 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||
|
||||
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||
instead of materializing a Python list of every frame index.
|
||||
|
||||
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||
resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||
starts (e.g. on resume or in tests).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
if from_indices.shape != to_indices.shape:
|
||||
raise ValueError(
|
||||
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||
f"got {len(from_indices)} and {len(to_indices)}"
|
||||
)
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
used = np.ones(len(from_indices), dtype=bool)
|
||||
if episode_indices_to_use is not None:
|
||||
used = np.zeros(len(from_indices), dtype=bool)
|
||||
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||
|
||||
starts = from_indices + drop_n_first_frames
|
||||
lengths = to_indices - drop_n_last_frames - starts
|
||||
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
to_indices[episode_idx] - from_indices[episode_idx],
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
used &= lengths > 0
|
||||
if not used.any():
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self._starts = starts[used]
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||
# and reproduces identically on every rank without touching the global RNG.
|
||||
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||
return torch.Generator().manual_seed(epoch_seed)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(int(order[k]))
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_frames
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||
return {"epoch": epoch, "start_index": start_index}
|
||||
return len(self.indices)
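The boundary-based lookup and the resume arithmetic in the sampler hunk above can be checked by hand with a small stand-alone sketch. It is an illustration only, not the repository's code: it uses `bisect_right` from the standard library in place of `np.searchsorted(..., side="right")`, and the names `build_boundaries`, `frame_index`, and `sampler_state` are hypothetical. The formulas follow the diff (`starts = from + drop_n_first_frames`, `lengths = to - drop_n_last_frames - starts`, and the `ceil(ceil(num_frames / batch_size) / num_processes)` batch count).

```python
import math
from bisect import bisect_right


def build_boundaries(from_indices, to_indices, drop_first=0, drop_last=0):
    """Return (starts, cum_lengths) for the episodes that keep at least one frame."""
    starts, cum_lengths, total = [], [], 0
    for start, end in zip(from_indices, to_indices):
        first = start + drop_first
        length = end - drop_last - first
        if length <= 0:
            continue  # every frame of this episode would be dropped
        total += length
        starts.append(first)
        cum_lengths.append(total)
    return starts, cum_lengths


def frame_index(starts, cum_lengths, position):
    """Map a logical position in [0, num_frames) to a dataset frame index."""
    episode = bisect_right(cum_lengths, position)
    offset = position - (cum_lengths[episode - 1] if episode > 0 else 0)
    return starts[episode] + offset


def sampler_state(step, num_frames, batch_size, num_processes):
    """Map an optimization step to (epoch, start_index), as in the diff's docstring."""
    batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
    epoch, batches_into_epoch = divmod(step, batches_per_epoch)
    start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
    return {"epoch": epoch, "start_index": start_index}


if __name__ == "__main__":
    # Three episodes covering frames [0, 10), [10, 25), [25, 30).
    starts, cum = build_boundaries([0, 10, 25], [10, 25, 30], drop_first=2, drop_last=1)
    print(starts, cum)
    print([frame_index(starts, cum, k) for k in (0, 6, 7, 20)])
    print(sampler_state(step=100, num_frames=1000, batch_size=8, num_processes=2))
```

With 1000 frames, a batch size of 8, and 2 processes, each rank sees 63 batches per epoch, so step 100 falls in epoch 1 at batch 37, which corresponds to a start index of 37 * 8 * 2 = 592. Storing only `starts` and `cum_lengths` keeps memory proportional to the number of episodes; each lookup costs one binary search.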
|
||||
|
||||
@@ -481,10 +481,8 @@ def reencode_video(
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
start_time_s: float | None = None,
|
||||
end_time_s: float | None = None,
|
||||
) -> None:
|
||||
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
@@ -493,17 +491,10 @@ def reencode_video(
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
start_time_s: When set, trim the output to start at this timestamp (seconds).
|
||||
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
|
||||
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
@@ -535,10 +526,6 @@ def reencode_video(
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
|
||||
if start_time_s is not None:
|
||||
src.seek(int(start_time_s * av.time_base), backward=True)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
@@ -552,14 +539,7 @@ def reencode_video(
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame_time_s = frame.time
|
||||
if start_time_s is not None and frame_time_s < start_time_s:
|
||||
continue
|
||||
if end_time_s is not None and frame_time_s >= end_time_s:
|
||||
break
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
if start_time_s is not None:
|
||||
frame.pts = None # reset timestamps so the trimmed output starts at t=0
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
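The trimming behaviour in the `reencode_video` hunk above (validate the bounds, skip frames before `start_time_s`, stop at the first frame at or after `end_time_s`) can be illustrated without PyAV. The sketch below works on plain timestamps; `validate_trim` and `select_frame_times` are hypothetical helper names, and it assumes frames arrive in increasing presentation time, which is what makes the early `break` safe.

```python
def validate_trim(start_time_s=None, end_time_s=None):
    """Reject negative bounds and empty or inverted windows."""
    if (start_time_s is not None and start_time_s < 0) or (
        end_time_s is not None and end_time_s < 0
    ):
        raise ValueError(
            f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}."
        )
    if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
        raise ValueError(
            f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s})."
        )


def select_frame_times(frame_times, start_time_s=None, end_time_s=None):
    """Keep timestamps inside the half-open window [start_time_s, end_time_s)."""
    validate_trim(start_time_s, end_time_s)
    kept = []
    for t in frame_times:  # assumed sorted in presentation order
        if start_time_s is not None and t < start_time_s:
            continue  # frame precedes the window (e.g. decoded after a keyframe seek)
        if end_time_s is not None and t >= end_time_s:
            break  # end bound is exclusive; later frames cannot qualify
        kept.append(t)
    return kept


if __name__ == "__main__":
    times = [i / 10 for i in range(10)]  # a 10 fps clip, one second long
    print(select_frame_times(times, start_time_s=0.25, end_time_s=0.55))
```

The `continue` branch matters because seeking lands on the keyframe at or before the requested start, so the decoder may still emit a few frames that precede the window.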
|
||||
|
||||
@@ -29,7 +29,6 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.__version__ import __version__
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
@@ -39,67 +38,6 @@ from .utils import log_model_loading_keys
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
|
||||
def _build_card_context(
|
||||
cfg: TrainPipelineConfig | None,
|
||||
dataset_repo_id: str | None,
|
||||
input_features: dict | None,
|
||||
output_features: dict | None,
|
||||
) -> dict:
|
||||
"""Collect optional data for the model-card template.
|
||||
|
||||
Returns plain values only (no Markdown) — the template in
|
||||
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
|
||||
each one. Everything is best-effort: anything unavailable is left empty/None and the
|
||||
template simply skips that section, so this never breaks a Hub push.
|
||||
"""
|
||||
context = {
|
||||
"training": None,
|
||||
"input_features": input_features or {},
|
||||
"output_features": output_features or {},
|
||||
"dataset": None,
|
||||
"robot_type": None,
|
||||
"cameras": [],
|
||||
}
|
||||
|
||||
if cfg is not None:
|
||||
optimizer = getattr(cfg, "optimizer", None)
|
||||
context["training"] = {
|
||||
"steps": cfg.steps,
|
||||
"batch_size": cfg.batch_size,
|
||||
"seed": cfg.seed,
|
||||
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
|
||||
"lr": getattr(optimizer, "lr", None) if optimizer else None,
|
||||
"lerobot_version": __version__,
|
||||
}
|
||||
|
||||
if dataset_repo_id:
|
||||
dataset_cfg = getattr(cfg, "dataset", None)
|
||||
try:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(
|
||||
dataset_repo_id,
|
||||
root=getattr(dataset_cfg, "root", None),
|
||||
revision=getattr(dataset_cfg, "revision", None),
|
||||
)
|
||||
context["dataset"] = {
|
||||
"repo_id": dataset_repo_id,
|
||||
"episodes": meta.total_episodes,
|
||||
"frames": meta.total_frames,
|
||||
"fps": meta.fps,
|
||||
"tasks": [str(task) for task in meta.tasks.index],
|
||||
}
|
||||
context["robot_type"] = meta.robot_type
|
||||
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
|
||||
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
|
||||
logging.warning(
|
||||
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
|
||||
f"omitted from the model card. ({e})"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
noise: Tensor | None
|
||||
|
||||
@@ -290,7 +228,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
@@ -308,20 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self,
|
||||
dataset_repo_id: str,
|
||||
model_type: str,
|
||||
license: str | None,
|
||||
tags: list[str] | None,
|
||||
cfg: TrainPipelineConfig | None = None,
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
"pi0": "lerobot/pi0_base",
|
||||
"pi05": "lerobot/pi05_base",
|
||||
"pi0_fast": "lerobot/pi0fast-base",
|
||||
"xvla": "lerobot/xvla-base",
|
||||
}
|
||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
||||
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
@@ -330,20 +257,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
base_model=base_model_mapping.get(model_type),
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
context = _build_card_context(
|
||||
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
|
||||
)
|
||||
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
|
||||
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
|
||||
context["base_model"] = base_model_mapping.get(model_type)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card, **context)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
+7
-9
@@ -1,6 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,12 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
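To make the distributional head's defaults concrete, here is a minimal sketch of how an HL-Gauss soft target could be built from `num_value_bins=201`, the support `[-1, 0]`, and `hl_gauss_sigma_ratio=5.0`. It follows one common formulation (integrate a Gaussian centred on the scalar return over each bin, then renormalize) and is not taken from the repository's training code, which this section does not show. The function names and the choice of bin edges at `center ± width / 2` are assumptions; the bin width and sigma formulas match the modeling file (`bin_width = (max - min) / (B - 1)`, `sigma = ratio * bin_width`).

```python
import math


def bin_centers(v_min=-1.0, v_max=0.0, num_bins=201):
    """Evenly spaced bin centers from v_min to v_max (inclusive)."""
    width = (v_max - v_min) / (num_bins - 1)
    return [v_min + i * width for i in range(num_bins)]


def hl_gauss_target(value, v_min=-1.0, v_max=0.0, num_bins=201, sigma_ratio=5.0):
    """Soft categorical target: Gaussian mass per bin, renormalized to sum to 1."""
    width = (v_max - v_min) / (num_bins - 1)
    sigma = sigma_ratio * width
    value = min(max(value, v_min), v_max)  # clamp the return into the support

    def cdf(x):
        return 0.5 * (1.0 + math.erf((x - value) / (sigma * math.sqrt(2.0))))

    # Assumed bin edges: half a bin width on either side of each center.
    probs = [cdf(c + width / 2) - cdf(c - width / 2) for c in bin_centers(v_min, v_max, num_bins)]
    total = sum(probs)
    return [p / total for p in probs]


def expected_value(probs, centers):
    """E[V] = sum_i p_i * center_i, the scalar read-out used at inference."""
    return sum(p * c for p, c in zip(probs, centers))


if __name__ == "__main__":
    centers = bin_centers()
    target = hl_gauss_target(-0.5)
    print(sum(target), expected_value(target, centers))
```

With the defaults, the bin width is 0.005 and sigma is 0.025, so most of the mass covers roughly ten bins on either side of the target. Near the edges of the support part of the Gaussian falls outside `[-1, 0]`, and renormalizing pulls the expected value slightly inward; exact targets for terminal states are one motivation for an option like `use_one_hot_terminal`.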
|
||||
+567
@@ -0,0 +1,567 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
        num_layers = config.num_hidden_layers

        # HuggingFace GemmaConfig for transformer layers
        gemma_config = CONFIG_MAPPING["gemma"](
            head_dim=base_config.head_dim,
            hidden_size=hidden_dim,
            intermediate_size=mlp_dim,
            num_attention_heads=base_config.num_heads,
            num_hidden_layers=num_layers,
            num_key_value_heads=base_config.num_kv_heads,
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            hidden_activation="gelu_pytorch_tanh",
        )
        self.gemma_config = gemma_config
        self.hidden_dim = hidden_dim
        self.num_value_bins = config.num_value_bins

        # Single learned [CLS] token for value prediction
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)

        # Value projection head: Linear(hidden_dim, num_bins)
        self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)

        # Transformer layers (overwritten by _initialize_from_actor on first run)
        self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
        pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
        self.layers = nn.ModuleList(
            [pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
        )
        self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)

        # Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
        # PaliGemmaConfig wraps both vision and text configs into a single model
        paligemma_config = CONFIG_MAPPING["paligemma"]()
        paligemma_config.text_config = gemma_config
        paligemma_config.vision_config.image_size = config.image_resolution[0]
        paligemma_config.vision_config.intermediate_size = 4304
        paligemma_config.vision_config.projection_dim = 2048
        paligemma_config.vision_config.projector_hidden_act = "gelu_fast"

        paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
        self.vision_tower = paligemma_full.model.vision_tower
        self.multi_modal_projector = paligemma_full.model.multi_modal_projector
        self.token_embedding = paligemma_full.model.language_model.embed_tokens
        del paligemma_full

        # Truncate vision tower to num_vision_layers
        if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
            vision_encoder = self.vision_tower.vision_model.encoder
            vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]

        # Bin support: evenly spaced centers from value_support_min to value_support_max
        bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
        self.register_buffer("bin_centers", bin_centers, persistent=False)
        bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
        self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)

        # Overwrite with pre-trained PI05 actor weights (first training run only)
        if config.init_from_actor_path:
            self._initialize_from_actor()

    def _initialize_from_actor(self) -> None:
        """Overwrite weights from a pre-trained PI05 actor checkpoint.

        Called on first training run only (when init_from_actor_path is set).
        """
        from lerobot.policies.pi05.modeling_pi05 import PI05Policy

        actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
        actor_model = actor_policy.model

        paligemma_model = actor_model.paligemma_with_expert.paligemma
        source_language_model = paligemma_model.model.language_model

        # Transformer components
        self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
        num_layers = self.gemma_config.num_hidden_layers
        for i in range(num_layers):
            self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
        self.norm.load_state_dict(source_language_model.norm.state_dict())

        # Vision tower (truncate source first, then copy)
        source_vision_tower = paligemma_model.model.vision_tower
        if hasattr(source_vision_tower, "vision_model") and hasattr(
            source_vision_tower.vision_model, "encoder"
        ):
            source_encoder = source_vision_tower.vision_model.encoder
            source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
        self.vision_tower.load_state_dict(source_vision_tower.state_dict())

        # Multi-modal projector
        self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())

        # Token embedding table
        self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())

        del actor_policy
|
||||
    def embed_image(self, image: Tensor) -> Tensor:
        """Embed images using the value function's SigLIP vision tower.

        Args:
            image: [batch_size, channels, height, width] preprocessed images in [-1, 1].

        Returns:
            [batch_size, num_patches, hidden_dim] projected image features.
        """
        out_dtype = image.dtype
        if image.dtype != torch.float32:
            image = image.to(torch.float32)

        image_outputs = self.vision_tower(image, return_dict=True)
        image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
        image_features = image_features / (self.hidden_dim**0.5)

        if image_features.dtype != out_dtype:
            image_features = image_features.to(out_dtype)
        return image_features

    def embed_text(self, token_ids: Tensor) -> Tensor:
        """Embed text token IDs using the value function's token embedding table.

        Args:
            token_ids: [batch_size, seq_len] integer token IDs

        Returns:
            [batch_size, seq_len, hidden_dim] text embeddings
        """
        return self.token_embedding(token_ids)

    def _get_cls_embedding(self, batch_size: int) -> Tensor:
        """Get [CLS] token embedding expanded to batch size.

        Args:
            batch_size: number of samples in the batch.

        Returns:
            [batch_size, 1, hidden_dim] learned [CLS] embedding.
        """
        return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
    def forward_value(
        self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
    ) -> dict[str, Tensor]:
        """Core forward pass through the distributional value function.

        Args:
            vision_features: [batch_size, num_patches, hidden_dim]
            text_embeddings: [batch_size, seq_len, hidden_dim]
            text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
                (True = real token, False = padding)

        Returns:
            Dict with keys:
                logits: [batch_size, num_value_bins]
                probs: [batch_size, num_value_bins]
                value: [batch_size, 1]
        """
        from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE

        batch_size = text_embeddings.shape[0]
        device = text_embeddings.device

        # Build sequence: [vision, text, CLS]
        cls_embedding = self._get_cls_embedding(batch_size)
        hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)

        # Build the padding mask for the full sequence (vision and [CLS] tokens are never padded)
        vision_len = vision_features.shape[1]
        vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
        cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
        full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)

        full_seq_len = full_padding_mask.shape[1]

        # Causal mask
        causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
        # Combine causal mask with padding mask
        padding_mask_4d = full_padding_mask[:, None, None, :].expand(
            batch_size, 1, full_seq_len, full_seq_len
        )
        attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
        attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)

        position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
        cos, sin = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            norm_output = layer.input_layernorm(hidden_states, cond=None)
            if isinstance(norm_output, tuple):
                hidden_states_normed, gate = norm_output
            else:
                hidden_states_normed, gate = norm_output, None

            input_shape = hidden_states_normed.shape[:-1]
            hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

            query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
            key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
            value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)

            query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
                query_states, key_states, cos, sin, unsqueeze_dim=1
            )

            attention_output, _ = modeling_gemma.eager_attention_forward(
                layer.self_attn,
                query_states,
                key_states,
                value_states,
                attention_mask,
                layer.self_attn.scaling,
            )

            attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
            if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
                attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
            projected_attention = layer.self_attn.o_proj(attention_output)

            if gate is not None:
                projected_attention = _gated_residual(hidden_states, projected_attention, gate)
            else:
                projected_attention = hidden_states + projected_attention

            after_attention_residual = projected_attention.clone()

            norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
            if isinstance(norm_output, tuple):
                mlp_input, gate = norm_output
            else:
                mlp_input, gate = norm_output, None

            mlp_output = layer.mlp(mlp_input)

            if gate is not None:
                hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
            else:
                hidden_states = after_attention_residual + mlp_output

        hidden_states = self.norm(hidden_states)
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        # Extract [CLS] token (last position in the sequence)
        cls_hidden_state = hidden_states[:, -1, :]  # [batch_size, hidden_dim]

        # Value head: Linear(hidden_dim, num_bins) -> logits
        value_logits = self.value_head(cls_hidden_state)  # [batch_size, num_value_bins]
        value_probs = F.softmax(value_logits, dim=-1)
        predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
            dim=-1, keepdim=True
        )

        return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
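The masking and position-id logic in `forward_value` can be illustrated without tensors. Below is a minimal plain-Python sketch for a single sequence, assuming the same convention as the method above (`True` marks a real token, `False` marks padding). The helper name `build_mask_and_positions` is hypothetical and is not part of the module.

```python
def build_mask_and_positions(padding_mask: list[bool]) -> tuple[list[list[bool]], list[int]]:
    """Combine a causal mask with a key-padding mask for one sequence.

    padding_mask[j] is True for a real token and False for padding.
    Returns (allowed, position_ids); allowed[i][j] says whether query i
    may attend to key j.
    """
    n = len(padding_mask)
    # Query i may see key j only if j is not in the future and j is a real token.
    allowed = [[(j <= i) and padding_mask[j] for j in range(n)] for i in range(n)]

    # Position ids: cumulative count of real tokens minus one,
    # so padded slots do not advance the position counter.
    position_ids = []
    running = 0
    for is_real in padding_mask:
        running += int(is_real)
        position_ids.append(running - 1)
    return allowed, position_ids


# Two vision tokens, two text tokens, one padded text slot, then [CLS].
mask = [True, True, True, True, False, True]
allowed, position_ids = build_mask_and_positions(mask)
cls_row = allowed[-1]
```

Because [CLS] sits at the last position, its row attends to every real vision and text token and skips the padded slot, which is why the value head can read the final position only.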
    def hl_gauss_target(self, target_value: Tensor) -> Tensor:
        """HL-Gauss soft target distribution.

        Places a Gaussian N(target, sigma^2) over the bin support and computes
        per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.

        Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
        Functions via Classification for Scalable Deep RL", Section 3.1.
        arXiv:2403.03950

        Args:
            target_value: [batch_size] or [batch_size, 1] target values.

        Returns:
            [batch_size, num_value_bins] target probability distribution.
        """
        if target_value.ndim == 2:
            target_value = target_value.squeeze(-1)
        target_value = target_value.to(dtype=self.bin_centers.dtype)

        # Bin edges: half a bin-width outside the first/last center
        bin_width = (self.config.value_support_max - self.config.value_support_min) / (
            self.num_value_bins - 1
        )
        support_edges = torch.linspace(
            self.config.value_support_min - bin_width / 2,
            self.config.value_support_max + bin_width / 2,
            self.num_value_bins + 1,
            device=target_value.device,
            dtype=target_value.dtype,
        )

        # CDF of N(target, sigma^2) evaluated at each edge
        cdf_at_edges = 0.5 * (
            1.0
            + torch.erf(
                (support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
                / (self.hl_gauss_sigma * math.sqrt(2))
            )
        )  # [batch_size, num_bins + 1]

        # Normalize: z = cdf(max_edge) - cdf(min_edge)
        normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)

        # Bin probabilities = differences of consecutive CDF values, normalized
        bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant

        return bin_probabilities
|
||||
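As a worked illustration of the HL-Gauss construction above, here is a plain-Python sketch for a single scalar target. It follows the same steps (edges half a bin-width outside the outermost centers, Gaussian CDF at each edge, normalized differences). The function name `hl_gauss_probs` and the value `sigma_ratio=0.75` are illustrative assumptions, not values taken from the configuration.

```python
import math


def hl_gauss_probs(
    target: float, v_min: float, v_max: float, num_bins: int, sigma_ratio: float
) -> list[float]:
    """Per-bin probabilities of N(target, sigma^2) over an evenly spaced support."""
    bin_width = (v_max - v_min) / (num_bins - 1)
    sigma = sigma_ratio * bin_width

    # num_bins + 1 edges, starting half a bin-width below the first center.
    edges = [v_min - bin_width / 2 + k * bin_width for k in range(num_bins + 1)]

    # Gaussian CDF evaluated at every edge.
    cdf = [0.5 * (1.0 + math.erf((e - target) / (sigma * math.sqrt(2)))) for e in edges]

    # Renormalize by the mass that falls inside the support.
    z = max(cdf[-1] - cdf[0], 1e-10)
    return [(cdf[k + 1] - cdf[k]) / z for k in range(num_bins)]


# Support [-1, 0] with 5 bins: centers at -1, -0.75, -0.5, -0.25, 0.
probs = hl_gauss_probs(target=-0.5, v_min=-1.0, v_max=0.0, num_bins=5, sigma_ratio=0.75)
peak_bin = probs.index(max(probs))
```

With the target placed exactly on the middle bin center, the distribution peaks at that bin and is symmetric around it; a smaller `sigma_ratio` concentrates more mass on the peak.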
    def dirac_delta_target(self, target_value: Tensor) -> Tensor:
        """Dirac delta (C51) projection: split probability between two nearest bins.

        Standard distributional RL projection from Bellemare et al. 2017,
        "A Distributional Perspective on Reinforcement Learning".
        arXiv:1707.06887

        Args:
            target_value: [batch_size] or [batch_size, 1] target values.

        Returns:
            [batch_size, num_value_bins] target probability distribution.
        """
        if target_value.ndim == 2:
            target_value = target_value.squeeze(-1)
        target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
        target_value = target_value.to(dtype=self.bin_centers.dtype)

        bin_width = self.bin_centers[1] - self.bin_centers[0]
        normalized_position = (target_value - self.config.value_support_min) / bin_width
        lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
        upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)

        weight_upper = normalized_position - lower_bin_idx.float()
        weight_lower = upper_bin_idx.float() - normalized_position

        # Target lies exactly on a bin center: put all mass on that bin
        same_bin = lower_bin_idx == upper_bin_idx
        weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
        weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)

        batch_size = target_value.shape[0]
        # Use the same dtype as the other target builders (bin_centers dtype)
        target_distribution = torch.zeros(
            batch_size, self.num_value_bins, device=target_value.device, dtype=target_value.dtype
        )
        batch_indices = torch.arange(batch_size, device=target_value.device)
        target_distribution[batch_indices, lower_bin_idx] += weight_lower
        target_distribution[batch_indices, upper_bin_idx] += weight_upper

        return target_distribution
|
||||
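The two-bin projection above has a useful property: the expected value of the resulting distribution equals the (clamped) target. The following plain-Python sketch for a single scalar shows this under the same evenly spaced support; the function name `two_hot` is hypothetical.

```python
import math


def two_hot(target: float, v_min: float, v_max: float, num_bins: int) -> list[float]:
    """Split unit mass between the two bin centers that bracket the target."""
    target = min(max(target, v_min), v_max)  # clamp into the support
    bin_width = (v_max - v_min) / (num_bins - 1)
    position = (target - v_min) / bin_width  # fractional bin index

    lower = min(max(math.floor(position), 0), num_bins - 1)
    upper = min(max(math.ceil(position), 0), num_bins - 1)

    probs = [0.0] * num_bins
    if lower == upper:
        # Target sits exactly on a bin center.
        probs[lower] = 1.0
    else:
        # The closer bin receives the larger share.
        probs[lower] += upper - position
        probs[upper] += position - lower
    return probs


centers = [-1.0, -0.75, -0.5, -0.25, 0.0]
probs = two_hot(-0.6, v_min=-1.0, v_max=0.0, num_bins=5)
expected = sum(p * c for p, c in zip(probs, centers))
```

Here the mass is shared between the bins centered at -0.75 and -0.5, and `expected` recovers the target up to floating-point error.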
    def one_hot_target(self, target_value: Tensor) -> Tensor:
        """One-hot target for terminal states (exact return, no smoothing).

        Args:
            target_value: [batch_size] or [batch_size, 1] target values.

        Returns:
            [batch_size, num_value_bins] one-hot distribution at the nearest bin.
        """
        if target_value.ndim == 2:
            target_value = target_value.squeeze(-1)
        target_value = target_value.to(dtype=self.bin_centers.dtype)
        nearest_bin_idx = torch.argmin(
            torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
        )
        return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)

    def compute_target_distribution(
        self,
        target_value: Tensor,
        is_terminal: Tensor,
        method: str = "hl_gauss",
        use_one_hot_terminal: bool = True,
    ) -> Tensor:
        """Compute the target distribution using the configured method.

        Args:
            target_value: [batch_size] scalar return targets
            is_terminal: [batch_size] boolean terminal flags
            method: "hl_gauss" or "dirac_delta"
            use_one_hot_terminal: if True, terminal states get one-hot targets
                (exact return, no smoothing). If False, all states use the same method.

        Returns:
            [batch_size, num_value_bins] target probability distribution
        """
        if method == "hl_gauss":
            base_distribution = self.hl_gauss_target(target_value)
        elif method == "dirac_delta":
            base_distribution = self.dirac_delta_target(target_value)
        else:
            raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")

        if not use_one_hot_terminal:
            return base_distribution

        terminal_distribution = self.one_hot_target(target_value)

        return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
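The per-row selection between terminal and non-terminal targets can be sketched in plain Python. This is only an illustration of the selection rule, assuming lists in place of tensors; `one_hot_nearest` and `select_targets` are hypothetical names.

```python
def one_hot_nearest(target: float, bin_centers: list[float]) -> list[float]:
    """One-hot row at the bin center closest to the target (first match on ties)."""
    nearest = min(range(len(bin_centers)), key=lambda k: abs(bin_centers[k] - target))
    return [1.0 if k == nearest else 0.0 for k in range(len(bin_centers))]


def select_targets(
    is_terminal: list[bool],
    terminal_rows: list[list[float]],
    base_rows: list[list[float]],
) -> list[list[float]]:
    """Row-wise choice: terminal samples take the one-hot row, others the smoothed row."""
    return [t if flag else b for flag, t, b in zip(is_terminal, terminal_rows, base_rows)]


centers = [-1.0, -0.5, 0.0]
base = [[0.2, 0.6, 0.2], [0.1, 0.3, 0.6]]  # smoothed targets for two samples
terminal = [one_hot_nearest(-0.45, centers), one_hot_nearest(0.0, centers)]
mixed = select_targets([False, True], terminal, base)
```

Only the second sample is terminal, so it alone is replaced by its one-hot row while the first keeps its smoothed distribution.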
    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
        """Training forward pass: computes cross-entropy loss against MC return targets.

        The batch is expected to be preprocessed by the processor pipeline.
        Keys expected in batch:
            - observation.images.*: [B, C, H, W] preprocessed images
            - observation.language_tokens: [B, seq_len] tokenized task prompt
            - observation.language_attention_mask: [B, seq_len] padding mask
            - mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
            - is_terminal: [B] boolean terminal flags

        Returns:
            (loss, output_dict) where loss is scalar cross-entropy
        """
        from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

        # Use the first image key found in the batch (the value function takes a single image)
        image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
        if not image_keys:
            raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
        images = batch[image_keys[0]]

        token_ids = batch[OBS_LANGUAGE_TOKENS]
        text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
        mc_return = batch["mc_return"]
        is_terminal = batch["is_terminal"]

        # Embed observations
        vision_features = self.embed_image(images)
        text_embeddings = self.embed_text(token_ids)

        # Forward through value function transformer
        vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
        value_logits = vf_output["logits"]
        predicted_value = vf_output["value"]

        # Compute target distribution
        target_distribution = self.compute_target_distribution(
            mc_return,
            is_terminal,
            method=self.config.target_method,
            use_one_hot_terminal=self.config.use_one_hot_terminal,
        )

        # Cross-entropy loss (Eq. 1 in pi*0.6 paper)
        log_probs = F.log_softmax(value_logits, dim=-1)
        loss = -(target_distribution * log_probs).sum(dim=-1).mean()

        output_dict = {
            "loss": loss.item(),
            "predicted_value_mean": predicted_value.mean().item(),
            "mc_return_mean": mc_return.mean().item(),
        }

        return loss, output_dict
|
||||
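The loss in `forward` is a cross-entropy between a soft target distribution and the softmax of the logits, and the scalar value is the probability-weighted mean of the bin centers. A minimal plain-Python sketch for one sample, assuming lists in place of tensors (all function names here are hypothetical):

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax using the max-subtraction trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def soft_cross_entropy(logits: list[float], target_probs: list[float]) -> float:
    """-sum_k target_k * log p_k, valid for one-hot and soft targets alike."""
    return -sum(t * lp for t, lp in zip(target_probs, log_softmax(logits)))


def expected_value(logits: list[float], bin_centers: list[float]) -> float:
    """Scalar value estimate: sum_k p_k * center_k."""
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return sum(p * c for p, c in zip(probs, bin_centers))


# Uniform logits over 4 bins against a one-hot target give a loss of ln(4).
uniform_loss = soft_cross_entropy([0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])
# Uniform logits over centers -1, -0.5, 0 give the midpoint of the support.
midpoint_value = expected_value([0.0, 0.0, 0.0], [-1.0, -0.5, 0.0])
```

Averaging `soft_cross_entropy` over the batch corresponds to the `.sum(dim=-1).mean()` reduction used in the method.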
    def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
        """Compute V(s) for a batch of observations. Used for advantage scoring.

        Args:
            batch: preprocessed batch with images and tokenized text

        Returns:
            [batch_size] tensor of predicted values V(s)
        """
        from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

        image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
        if not image_keys:
            raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
        images = batch[image_keys[0]]

        token_ids = batch[OBS_LANGUAGE_TOKENS]
        text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()

        vision_features = self.embed_image(images)
        text_embeddings = self.embed_text(token_ids)

        vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
        return vf_output["value"].squeeze(-1)  # [batch_size]
+235
@@ -0,0 +1,235 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processor for RECAP's distributional value function.

Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06

Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
    1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
    2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer

Training targets (mc_return, is_terminal) are NOT routed through the processor.
They are dataset columns read directly from the batch in the model's forward().
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch
from torch import Tensor

from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    batch_to_transition,
    policy_action_to_transition,
    transition_to_batch,
)
from lerobot.processor.converters import to_tensor
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
    OBS_IMAGES,
    POLICY_POSTPROCESSOR_DEFAULT_NAME,
    POLICY_PREPROCESSOR_DEFAULT_NAME,
)

from .configuration_distributional_value_function import DistributionalVFConfig

PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
    """Format the task string for the distributional value function.

    The value function receives only visual observations and task text.
    Builds prompt: "Task: {task}."
    """

    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        transition = transition.copy()

        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        tasks = complementary_data.get(self.task_key)
        if tasks is None:
            raise ValueError("No task found in complementary data")

        if isinstance(tasks, str):
            tasks = [tasks]

        full_prompts = []
        for task in tasks:
            cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
            full_prompts.append(f"Task: {cleaned_text}.")

        new_complementary_data = dict(complementary_data)
        new_complementary_data[self.task_key] = full_prompts
        transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features

    def get_config(self) -> dict[str, Any]:
        return {"task_key": self.task_key}
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
    """Resize and normalize images for the value function's SigLIP vision tower.

    Expects float images in [0, 1].
        - Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
        - Scale to [-1, 1] for SigLIP
    """

    image_resolution: tuple[int, int] = (224, 224)
    image_keys: tuple[str, ...] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch

        observation = transition.get(TransitionKey.OBSERVATION)
        if not isinstance(observation, dict):
            raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")

        image_keys = self.image_keys or tuple(
            key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
        )
        if not image_keys:
            raise KeyError(
                f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
            )

        new_observation = dict(observation)
        for image_key in image_keys:
            image = new_observation[image_key]
            if not isinstance(image, Tensor):
                image = to_tensor(image)
            if image.dtype != torch.float32:
                image = image.to(torch.float32)

            is_channels_first = image.ndim == 4 and image.shape[1] == 3
            if is_channels_first:
                image = image.permute(0, 2, 3, 1)

            if image.shape[1:3] != self.image_resolution:
                image = resize_with_pad_torch(image, *self.image_resolution)

            image = image * 2.0 - 1.0

            if is_channels_first:
                image = image.permute(0, 3, 1, 2)

            new_observation[image_key] = image

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features

    def get_config(self) -> dict[str, Any]:
        return {
            "image_resolution": self.image_resolution,
            "image_keys": list(self.image_keys) if self.image_keys is not None else None,
        }
|
||||
|
||||
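To make the two image operations concrete, here is a plain-Python sketch of generic letterbox geometry (aspect-preserving resize followed by padding) and of the [0, 1] to [-1, 1] rescaling. It is an illustration of the idea only: the exact rounding and padding placement used by `resize_with_pad_torch` are not shown in this diff, so the centered padding below is an assumption, and both function names are hypothetical.

```python
def letterbox_geometry(
    src_h: int, src_w: int, dst_h: int, dst_w: int
) -> tuple[int, int, tuple[int, int], tuple[int, int]]:
    """Return (new_h, new_w, (pad_top, pad_bottom), (pad_left, pad_right)).

    Integer arithmetic avoids floating-point rounding surprises.
    """
    if src_w * dst_h >= src_h * dst_w:
        # Width is the limiting side: fill the target width, shrink the height.
        new_w = dst_w
        new_h = src_h * dst_w // src_w
    else:
        # Height is the limiting side.
        new_h = dst_h
        new_w = src_w * dst_h // src_h
    pad_h = dst_h - new_h
    pad_w = dst_w - new_w
    # Assumed here: padding is split as evenly as possible between both sides.
    return new_h, new_w, (pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)


def to_siglip_range(pixel: float) -> float:
    """Map a pixel value from [0, 1] to [-1, 1]."""
    return pixel * 2.0 - 1.0


geometry = letterbox_geometry(480, 640, 224, 224)
```

For a 480x640 camera frame and a 224x224 target, the image is resized to 168x224 and padded by 28 rows above and below under these assumptions.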
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
    return tuple(
        feature_name
        for feature_name, feature in config.input_features.items()
        if feature.type == FeatureType.VISUAL
    )


def make_distributional_vf_pre_post_processors(
    config: DistributionalVFConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create pre/post processors for the distributional value function.

    Preprocessor steps:
        1. Rename observations (no-op by default)
        2. Add a batch dimension
        3. Normalize features (images use identity, so they stay in [0, 1])
        4. Format task prompt: "Task: {task}."
        5. Tokenize with the PaliGemma tokenizer
        6. Resize-with-pad and scale images to [-1, 1] for SigLIP
        7. Move tensors to the configured device

    Training targets (mc_return, is_terminal) are not processed here.
    The model reads them directly from the batch in forward().

    The postprocessor is a no-op because the value function does not need
    action postprocessing.
    """
    image_keys = _visual_image_keys(config)

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=[
            RenameObservationsProcessorStep(rename_map={}),
            AddBatchDimensionProcessorStep(),
            NormalizerProcessorStep(
                features={**config.input_features, **config.output_features},
                norm_map=config.normalization_mapping,
                stats=dataset_stats,
            ),
            DistributionalVFPrepareTaskPromptStep(),
            TokenizerProcessorStep(
                tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
                max_length=config.tokenizer_max_length,
                padding_side="right",
                padding="max_length",
            ),
            DistributionalVFImagePreprocessorStep(
                image_resolution=config.image_resolution,
                image_keys=image_keys or None,
            ),
            DeviceProcessorStep(device=config.device or "cpu"),
        ],
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        to_transition=batch_to_transition,
        to_output=transition_to_batch,
    )
    postprocessor = PolicyProcessorPipeline(
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
    )
    return preprocessor, postprocessor
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline

from .classifier.configuration_classifier import RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
        from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel

        return TOPRewardModel
    elif name == "distributional_value_function":
        from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
            DistributionalVFRewardModel,
        )

        return DistributionalVFRewardModel
    else:
        try:
            return _get_reward_model_cls_from_name(name=name)
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
        return RobometerConfig(**kwargs)
    elif reward_type == "topreward":
        return TOPRewardConfig(**kwargs)
    elif reward_type == "distributional_value_function":
        return DistributionalVFConfig(**kwargs)
    else:
        try:
            config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(reward_cfg, DistributionalVFConfig):
        from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
            make_distributional_vf_pre_post_processors,
        )

        return make_distributional_vf_pre_post_processors(
            config=reward_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    else:
        try:
            processors = _make_processors_from_reward_model_config(
|
||||
@@ -1,206 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||
``language_events`` columns on a LeRobot dataset.
|
||||
|
||||
Annotations live directly in ``data/chunk-*/file-*.parquet``.
|
||||
|
||||
Example:
|
||||
|
||||
uv run lerobot-annotate \\
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs import parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
|
||||
if cfg.root is not None:
|
||||
return Path(cfg.root)
|
||||
if cfg.repo_id is not None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
|
||||
raise ValueError("Either --root or --repo_id must be provided.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Run the steerable annotation pipeline against a dataset."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
root = _resolve_root(cfg)
|
||||
logger.info("annotate: root=%s", root)
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend)
|
||||
# Surface the resolved cameras up front so a silent vqa-module no-op
|
||||
# is obvious in job output rather than discovered post-hoc by counting
|
||||
# parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.vqa.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: the vqa module is enabled but no cameras were "
|
||||
"resolved — it will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
|
||||
interjections = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(cam_keys) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
summary = executor.run(root)
|
||||
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
|
||||
for phase in summary.phases:
|
||||
logger.info(
|
||||
"annotate: phase=%s processed=%d skipped=%d",
|
||||
phase.name,
|
||||
phase.episodes_processed,
|
||||
phase.episodes_skipped,
|
||||
)
|
||||
if summary.validation_report.warnings:
|
||||
for w in summary.validation_report.warnings:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
if cfg.repo_id is None and cfg.new_repo_id is None:
|
||||
raise ValueError(
|
||||
"--push_to_hub requires --repo_id or --new_repo_id (the dataset repo to push to)."
|
||||
)
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hub.
|
||||
|
||||
Pushes to ``cfg.new_repo_id`` when set, otherwise back to ``cfg.repo_id``.
|
||||
"""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.new_repo_id or cfg.repo_id
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
private=cfg.push_private,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||
commit_info = api.upload_folder(
|
||||
folder_path=str(root),
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}",
|
||||
flush=True,
|
||||
)
|
||||
revision = getattr(commit_info, "oid", None)
|
||||
tag_kwargs = {
|
||||
"repo_id": repo_id,
|
||||
"tag": version_tag,
|
||||
"repo_type": "dataset",
|
||||
}
|
||||
if revision is not None:
|
||||
tag_kwargs["revision"] = revision
|
||||
|
||||
try:
|
||||
from contextlib import suppress # noqa: PLC0415
|
||||
|
||||
from huggingface_hub.errors import RevisionNotFoundError # noqa: PLC0415
|
||||
|
||||
with suppress(RevisionNotFoundError):
|
||||
api.delete_tag(repo_id, tag=version_tag, repo_type="dataset")
|
||||
api.create_tag(**tag_kwargs)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
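The version-tag fallback in `_push_to_hub` above can be isolated into a small standalone sketch. The helper name `resolve_version_tag` and the default value `"v3.0"` are assumptions made for illustration; the original takes its default from `CODEBASE_VERSION` and reads the file through `load_json`.

```python
import json
import tempfile
from pathlib import Path

# Assumption: stands in for lerobot's CODEBASE_VERSION constant.
DEFAULT_VERSION_TAG = "v3.0"


def resolve_version_tag(root: Path, default: str = DEFAULT_VERSION_TAG) -> str:
    """Return the version recorded in meta/info.json, or `default` if unusable."""
    info_path = root / "meta" / "info.json"
    if not info_path.exists():
        return default
    try:
        info = json.loads(info_path.read_text())
    except (OSError, ValueError):
        # Unreadable or malformed file: keep the default tag.
        return default
    if not isinstance(info, dict):
        return default
    ds_version = info.get("codebase_version")
    # Only accept strings shaped like a version tag, e.g. "v2.1".
    if isinstance(ds_version, str) and ds_version.startswith("v"):
        return ds_version
    return default


def demo() -> tuple[str, str]:
    """Resolve the tag once without info.json and once with it."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        missing = resolve_version_tag(root)
        (root / "meta").mkdir()
        (root / "meta" / "info.json").write_text(
            json.dumps({"codebase_version": "v2.1"})
        )
        found = resolve_version_tag(root)
    return missing, found
```

Calling `demo()` exercises both branches: the default is returned while `meta/info.json` is absent, and the dataset's own version is returned once the file exists.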
||||
@@ -94,14 +94,6 @@ Merge multiple datasets from a list of local dataset paths:
|
||||
--operation.repo_ids "['pusht_train', 'pusht_val']" \
|
||||
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
|
||||
|
||||
Merge multiple datasets while keeping one file per source file (no video/data stitching):
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \
|
||||
--operation.concatenate_videos false \
|
||||
--operation.concatenate_data false
|
||||
|
||||
Remove camera feature:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -265,9 +257,6 @@ class SplitConfig(OperationConfig):
|
||||
class MergeConfig(OperationConfig):
|
||||
repo_ids: list[str] | None = None
|
||||
roots: list[str] | None = None
|
||||
# When False, keep one file per source file instead of packing into shards.
|
||||
concatenate_videos: bool = True
|
||||
concatenate_data: bool = True
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("remove_feature")
|
||||
@@ -472,8 +461,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
datasets,
|
||||
output_repo_id=cfg.new_repo_id,
|
||||
output_dir=output_dir,
|
||||
concatenate_videos=cfg.operation.concatenate_videos,
|
||||
concatenate_data=cfg.operation.concatenate_data,
|
||||
)
|
||||
|
||||
logging.info(f"Merged dataset saved to {output_dir}")
|
||||
|
||||
@@ -72,9 +72,8 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -85,7 +84,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -96,81 +95,6 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
|
||||
|
||||
If raw_obs is provided, visual feature shapes are inferred from the actual observation
|
||||
to avoid mismatches between the env config and the real observation resolution.
|
||||
"""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
elif raw_obs is not None and "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
|
||||
shape = img.shape[1:] # strip batch dim
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
shape = pixels.shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
return features
|
||||
|
||||
|
||||
def _build_raw_frame(
|
||||
raw_obs: dict,
|
||||
env_idx: int,
|
||||
action: np.ndarray,
|
||||
reward: float,
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
val = raw_obs[raw_key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -181,7 +105,6 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -222,14 +145,6 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
raw_observation = deepcopy(observation) if recording_dataset is not None else None
|
||||
task_desc = ""
|
||||
if recording_dataset is not None:
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -302,26 +217,6 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_dataset is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_dataset.save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
@@ -378,7 +273,6 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -467,7 +361,6 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -670,10 +563,6 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -683,13 +572,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -732,7 +618,6 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -750,7 +635,6 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -777,8 +661,6 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -790,39 +672,21 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
recording_dataset = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=str(task_recording_dir),
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
finally:
|
||||
if recording_dataset is not None:
|
||||
recording_dataset.finalize()
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -838,8 +702,6 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -899,8 +761,6 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -36,8 +36,6 @@ from tqdm import tqdm
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
@@ -45,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -101,9 +99,6 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
@@ -163,8 +158,6 @@ def update_policy(
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
if torch.cuda.is_available():
|
||||
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -239,16 +232,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: the global main process downloads once to the shared
|
||||
# dataset root, then a barrier lets every other rank read the already-populated copy.
|
||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Other ranks read from the shared copy populated by the main process.
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
@@ -393,47 +384,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if not cfg.dataset.streaming:
|
||||
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
|
||||
# The order is a pure function of (seed, epoch), so every rank independently produces the
|
||||
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
|
||||
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||
# use the values recorded in the checkpoint (falling back to the current ones for older
|
||||
# ckpts that did not store them).
|
||||
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
|
||||
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||
ckpt_batch_size = saved_batch_size or cfg.batch_size
|
||||
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||
logging.warning(
|
||||
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||
)
|
||||
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
|
||||
logging.warning(
|
||||
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
|
||||
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
|
||||
"but per-rank sample-exactness requires the same batch size."
|
||||
)
|
||||
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
|
||||
sampler.load_state_dict(sampler_state)
|
||||
if is_main_process:
|
||||
logging.info(
|
||||
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||
f"sample {sampler_state['start_index']}"
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
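The comments in the hunk above say that the resume offset depends on the `(num_processes, batch_size)` pair that produced `step`. The sketch below shows one plausible way to get an epoch and an in-epoch offset from those values. It is not the implementation of `compute_sampler_state`, which this diff does not show. The function name `sketch_sampler_state` is hypothetical, and the arithmetic assumes every step consumes exactly `batch_size * num_processes` samples with no per-epoch padding or dropped batches.

```python
def sketch_sampler_state(
    step: int, num_samples: int, batch_size: int, num_processes: int
) -> dict[str, int]:
    """Hypothetical resume-state arithmetic; not the library's actual function.

    Assumes each optimizer step consumes batch_size * num_processes samples
    and that epochs wrap with no padding or dropped batches.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    consumed = step * batch_size * num_processes
    return {
        "epoch": consumed // num_samples,        # full passes already completed
        "start_index": consumed % num_samples,   # offset inside the current pass
    }


state = sketch_sampler_state(step=10, num_samples=50, batch_size=4, num_processes=2)
```

Under these assumptions, changing either `batch_size` or `num_processes` between the checkpoint and the resumed run changes `consumed`, which is why the hunk reads both values from the checkpoint rather than from the current configuration.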
||||
@@ -465,22 +424,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
|
||||
# is actually optimizing. grad_norm and lr are already identical on every rank (post
|
||||
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
|
||||
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
|
||||
# true straggler instead of rank 0's view.
|
||||
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
|
||||
# Derived from the post-reduce max step time; set once per log window on the main rank.
|
||||
"samples_per_s": AverageMeter("smp/s", ":.0f"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
# max() because headroom is gated by the worst-case rank.
|
||||
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
@@ -532,29 +481,21 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
train_tracker.reduce_across_ranks()
|
||||
if is_main_process:
|
||||
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
|
||||
# reflects the slowest rank — which is what actually gates the next iteration.
|
||||
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
|
||||
if step_time > 0:
|
||||
train_tracker.samples_per_s = effective_batch_size / step_time
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
@@ -570,8 +511,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -13,213 +13,77 @@
|
||||
[SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware.
|
||||
{% elif model_name == "act" %}
|
||||
[Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
|
||||
{% elif model_name == "tdmpc" %}
|
||||
[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function.
|
||||
{% elif model_name == "diffusion" %}
|
||||
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
|
||||
{% elif model_name == "vqbet" %}
|
||||
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
|
||||
{% elif model_name == "pi0" %}
|
||||
[π₀ (Pi0)](https://www.physicalintelligence.company/blog/pi0) is a general-purpose robot foundation model from Physical Intelligence: a generalist Vision-Language-Action policy that understands visual inputs, interprets natural language instructions, and controls a variety of different robots across diverse tasks. The LeRobot implementation is adapted from their open-source OpenPI repository.
|
||||
**π₀ (Pi0)**
|
||||
|
||||
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
|
||||
|
||||
**Model Overview**
|
||||
|
||||
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
|
||||
{% elif model_name == "pi05" %}
|
||||
[π₀.₅ (Pi05)](https://www.physicalintelligence.company/blog/pi05) is a Vision-Language-Action model from Physical Intelligence designed for open-world generalization: it evolves π₀ to generalize to entirely new environments and situations that were never seen during training. The LeRobot implementation is adapted from their open-source OpenPI repository.
|
||||
{% elif model_name == "molmoact2" %}
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2) is an open robotics foundation model from the Allen Institute for AI (Ai2) that maps camera images and language instructions to robot action chunks. The LeRobot implementation supports training and evaluation of the regular MolmoAct2 model.
|
||||
{% elif model_name == "vla_jepa" %}
|
||||
[VLA-JEPA](https://arxiv.org/abs/2602.10098) is a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
**π₀.₅ (Pi05) Policy**
|
||||
|
||||
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
|
||||
|
||||
**Model Overview**
|
||||
|
||||
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
|
||||
{% elif model_name == "gaussian_actor" %}
|
||||
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
|
||||
{% elif model_name == "pi0_fast" %}
|
||||
[π₀-FAST (Pi0-FAST)](https://www.physicalintelligence.company/research/fast) is a Vision-Language-Action model for general robot control, from Physical Intelligence. It models continuous robot actions with autoregressive next-token prediction using FAST (Frequency-space Action Sequence Tokenization), training up to 5x faster than diffusion-based π₀.
|
||||
{% elif model_name == "eo1" %}
|
||||
[EO-1](https://huggingface.co/papers/2508.21112) is a Vision-Language-Action model for general robot control. It pairs a Qwen2.5-VL backbone for vision-language understanding with a continuous flow-matching action head that denoises action chunks.
|
||||
{% elif model_name == "groot" %}
|
||||
[GR00T N1.5](https://github.com/NVIDIA/Isaac-GR00T) is an open, cross-embodiment foundation model from NVIDIA for generalized humanoid robot reasoning and skills. It takes language and images as input and uses a flow-matching action transformer to predict actions conditioned on vision, language, and proprioception.
|
||||
{% elif model_name == "multi_task_dit" %}
|
||||
[Multi-Task Diffusion Transformer (DiT)](https://huggingface.co/papers/2507.05331) extends Diffusion Policy with a large Diffusion Transformer and text + vision conditioning for multi-task robot learning. It supports both diffusion and flow-matching objectives and reaches high dexterity with only ~450M parameters.
|
||||
{% elif model_name == "wall_x" %}
|
||||
[WALL-OSS](https://huggingface.co/papers/2509.11766) is an open-source foundation model for embodied intelligence from XSquare Robot. Built on Qwen2.5-VL, it uses a tightly-coupled multimodal architecture with flow matching to unify semantic reasoning and high-frequency action generation for cross-embodiment control.
|
||||
{% elif model_name == "xvla" %}
|
||||
[X-VLA](https://huggingface.co/papers/2510.10274) is a soft-prompted, flow-matching Vision-Language-Action framework that treats each robot or hardware setup as a "task" encoded with a small set of learnable Soft Prompt embeddings, letting a single model reconcile diverse robot morphologies, sensors, and action spaces.
|
||||
{% else %}
|
||||
This is a **{{ model_name }}** policy trained with [LeRobot](https://github.com/huggingface/lerobot).
|
||||
_Model type not recognized — please update this template._
|
||||
{% endif %}
|
||||
{% set diagrams = {
|
||||
"smolvla": "https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png",
|
||||
"pi0": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png",
|
||||
"pi0_fast": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png",
|
||||
"eo1": "https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png",
|
||||
"groot": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png",
|
||||
"wall_x": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png",
|
||||
"xvla": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||
} %}
|
||||
{% if diagrams.get(model_name) %}
|
||||
<p align="center">
|
||||
<img src="{{ diagrams[model_name] }}" alt="{{ model_name }} architecture" width="85%"/>
|
||||
</p>
|
||||
{% endif %}
|
||||
|
||||
<!-- A short demo is worth more than any description! Record a GIF/video of the policy
|
||||
running on your robot, upload it to this repo, and embed it here:
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/<hf_user>/<policy_repo_id>/resolve/main/demo.gif" width="60%"/>
|
||||
</p>
|
||||
-->
|
||||
|
||||
This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
{% set policy_docs = {
|
||||
"act": "act",
|
||||
"smolvla": "smolvla",
|
||||
"pi0": "pi0",
|
||||
"pi0_fast": "pi0fast",
|
||||
"pi05": "pi05",
|
||||
"molmoact2": "molmoact2",
|
||||
"vla_jepa": "vla_jepa",
|
||||
"eo1": "eo1",
|
||||
"groot": "groot",
|
||||
"xvla": "xvla",
|
||||
"multi_task_dit": "multi_task_dit",
|
||||
"wall_x": "walloss"
|
||||
} %}
|
||||
{% if policy_docs.get(model_name) %}Learn how to train and run it in the [LeRobot {{ model_name }} guide](https://huggingface.co/docs/lerobot/main/en/{{ policy_docs[model_name] }}), or browse the [full documentation](https://huggingface.co/docs/lerobot/index).
|
||||
{% else %}See the [full LeRobot documentation](https://huggingface.co/docs/lerobot/index).
|
||||
{% endif %}
|
||||
See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
|
||||
|
||||
---
|
||||
|
||||
## How to Get Started with the Model
|
||||
|
||||
For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
|
||||
Below is the short version of how to train and run inference/eval:
|
||||
|
||||
### Train from scratch
|
||||
|
||||
```bash
lerobot-train \
  --dataset.repo_id=${HF_USER}/<dataset> \
  --policy.type=act \
  --output_dir=outputs/train/<desired_policy_repo_id> \
  --job_name=lerobot_training \
  --policy.device=cuda \
  --policy.repo_id=${HF_USER}/<desired_policy_repo_id> \
  --wandb.enable=true
```
|
||||
|
||||
_Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
||||
|
||||
### Evaluate the policy/run inference
|
||||
|
||||
```bash
lerobot-record \
  --robot.type=so100_follower \
  --dataset.repo_id=<hf_user>/eval_<dataset> \
  --policy.path=<hf_user>/<desired_policy_repo_id> \
  --episodes=10
```
|
||||
|
||||
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Model Details
|
||||
|
||||
- **License:** {{ license | default("\[More Information Needed]", true) }}
|
||||
{% if base_model %}- **Fine-tuned from:** [{{ base_model }}](https://huggingface.co/{{ base_model }})
|
||||
{% endif %}{% if robot_type %}- **Robot type:** `{{ robot_type }}`
|
||||
{% endif %}{% if cameras %}- **Cameras:** {% for camera in cameras %}`{{ camera }}`{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
{% endif %}
|
||||
{% if input_features or output_features %}
|
||||
## Inputs & Outputs
|
||||
|
||||
The policy consumes these observation features and produces these action features.
|
||||
{% if input_features %}
|
||||
**Inputs**
|
||||
|
||||
| Feature | Type | Shape |
|
||||
| --- | --- | --- |
|
||||
{% for name, feature in input_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
|
||||
{% endfor %}{% endif %}{% if output_features %}
|
||||
**Outputs**
|
||||
|
||||
| Feature | Type | Shape |
|
||||
| --- | --- | --- |
|
||||
{% for name, feature in output_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
|
||||
{% endfor %}{% endif %}{% endif %}
|
||||
{% if dataset %}
|
||||
## Training Dataset
|
||||
|
||||
- **Repository:** [{{ dataset.repo_id }}](https://huggingface.co/datasets/{{ dataset.repo_id }})
|
||||
- **Episodes:** {{ dataset.episodes }}
|
||||
- **Frames:** {{ dataset.frames }}
|
||||
- **Frame rate:** {{ dataset.fps }} FPS
|
||||
{% if dataset.tasks %}- **Task(s):** {% for task in dataset.tasks %}"{{ task }}"{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
{% endif %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ dataset.repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
{% if training %}
|
||||
## Training Configuration
|
||||
|
||||
| Setting | Value |
|
||||
| --- | --- |
|
||||
| Training steps | {{ training.steps }} |
|
||||
| Batch size | {{ training.batch_size }} |
|
||||
{% if training.optimizer %}| Optimizer | {{ training.optimizer }} |
|
||||
{% endif %}{% if training.lr %}| Learning rate | {{ training.lr }} |
|
||||
{% endif %}{% if training.seed is not none %}| Seed | {{ training.seed }} |
|
||||
{% endif %}| LeRobot version | {{ training.lerobot_version }} |
|
||||
{% endif %}
|
||||
---
|
||||
|
||||
## How to Get Started with the Model
|
||||
|
||||
New to LeRobot? These guides cover the full workflow:
|
||||
|
||||
- **[Install LeRobot](https://huggingface.co/docs/lerobot/main/en/installation)** — set up the `lerobot` package.
|
||||
- **[Hardware setup](https://huggingface.co/docs/lerobot/main/en/hardware_guide)** — assemble, wire, and calibrate your robot and cameras.
|
||||
- **[Record data & train a policy](https://huggingface.co/docs/lerobot/en/il_robots)** — the end-to-end imitation-learning walkthrough.
|
||||
- **[CLI cheat-sheet](https://huggingface.co/docs/lerobot/main/en/cheat-sheet)** — quick reference for the `lerobot-*` commands.
|
||||
|
||||
The short version to run and train this policy:
|
||||
|
||||
### Run the policy on your robot
|
||||
|
||||
```bash
lerobot-rollout \
  --strategy.type=base \
  --robot.type={{ robot_type | default("<your_robot_type>", true) }} \
  --robot.port=<your_robot_port> \
  --robot.cameras="{ <camera_1>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}, <camera_2>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}}" \
  --policy.path={{ policy_repo_id | default("<hf_user>/<policy_repo_id>", true) }} \
  --task="{% if dataset and dataset.tasks %}{{ dataset.tasks[0] }}{% else %}<your_task_description>{% endif %}" \
  --duration=60
```
|
||||
|
||||
Replace the remaining `<...>` placeholders with your own values: `--robot.port` and the camera names/indices are specific to your machine, and the camera names must match the observation keys this policy was trained on.
|
||||
|
||||
With `--strategy.type=base`, the script does not record episodes. Omitting `--duration` makes the policy run indefinitely. For more information, see the [rollout documentation](https://huggingface.co/docs/lerobot/main/en/inference).
|
||||
|
||||
{% if base_model %}### Train your own policy
|
||||
|
||||
This policy type is usually fine-tuned from the pretrained base model [{{ base_model }}](https://huggingface.co/{{ base_model }}):
|
||||
|
||||
```bash
lerobot-train \
  --dataset.repo_id=${HF_USER}/<dataset> \
  --policy.path={{ base_model }} \
  --output_dir=outputs/train/<policy_repo_id> \
  --job_name=lerobot_training \
  --policy.device=cuda \
  --policy.repo_id=${HF_USER}/<policy_repo_id> \
  --wandb.enable=true
```
|
||||
{% else %}### Train your own policy
|
||||
|
||||
```bash
lerobot-train \
  --dataset.repo_id=${HF_USER}/<dataset> \
  --policy.type={{ model_name }} \
  --output_dir=outputs/train/<policy_repo_id> \
  --job_name=lerobot_training \
  --policy.device=cuda \
  --policy.repo_id=${HF_USER}/<policy_repo_id> \
  --wandb.enable=true
```
|
||||
{% endif %}
|
||||
_Writes checkpoints to `outputs/train/<policy_repo_id>/checkpoints/`._
|
||||
|
||||
---
|
||||
|
||||
## Evaluation
|
||||
|
||||
<!-- Report real-robot results here: run the policy several times per task and count the
|
||||
successes. Delete the "No evaluation results" line and fill in this table instead:
|
||||
|
||||
| Task | Trials | Successes | Success rate |
|
||||
| ---- | ------ | --------- | ------------ |
|
||||
| pick the lego brick | 10 | 8 | 80% |
|
||||
|
||||
Also worth noting: anything that affects difficulty (new object positions, lighting,
|
||||
distractors, a different robot of the same type, ...).
|
||||
-->
|
||||
|
||||
_No evaluation results have been provided for this policy yet._
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this policy, please cite the method linked in the description above, along with LeRobot:
|
||||
|
||||
```bibtex
@misc{cadene2024lerobot,
    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
    title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
    howpublished = "\url{https://github.com/huggingface/lerobot}",
    year = {2024}
}
```
|
||||
|
||||
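The `policy_docs` lookup in the model-card template above picks a per-policy guide when one is mapped and falls back to the documentation index otherwise. The following is a minimal Python sketch of that branching, assuming the same mapping and URLs as the template; `docs_url` and the constant names are hypothetical helpers for illustration and are not part of LeRobot.

```python
# Mapping copied from the template's ``policy_docs`` dict:
# policy type -> documentation page slug.
POLICY_DOCS = {
    "act": "act",
    "smolvla": "smolvla",
    "pi0": "pi0",
    "pi0_fast": "pi0fast",
    "pi05": "pi05",
    "molmoact2": "molmoact2",
    "vla_jepa": "vla_jepa",
    "eo1": "eo1",
    "groot": "groot",
    "xvla": "xvla",
    "multi_task_dit": "multi_task_dit",
    "wall_x": "walloss",
}
DOCS_BASE = "https://huggingface.co/docs/lerobot/main/en"
DOCS_INDEX = "https://huggingface.co/docs/lerobot/index"


def docs_url(model_name: str) -> str:
    """Hypothetical helper: per-policy guide if mapped, else the docs index."""
    slug = POLICY_DOCS.get(model_name)
    if slug:
        return f"{DOCS_BASE}/{slug}"
    return DOCS_INDEX


mapped = docs_url("wall_x")       # slug differs from the policy type
unmapped = docs_url("diffusion")  # no entry in the mapping, so the index is used
```

Note that the slug can differ from the policy type (`wall_x` maps to `walloss`, `pi0_fast` to `pi0fast`), which is why the template keeps an explicit mapping rather than reusing `model_name` in the URL.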
@@ -13,39 +13,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import format_big_number
|
||||
|
||||
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
|
||||
|
||||
Args:
|
||||
name: Display name of the metric.
|
||||
fmt: Format string used when rendering the metric.
|
||||
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
|
||||
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
|
||||
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
|
||||
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
|
||||
if reduction not in _VALID_REDUCTIONS:
|
||||
raise ValueError(
|
||||
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
|
||||
)
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reduction = reduction
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -156,37 +138,6 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def reduce_across_ranks(self) -> None:
|
||||
"""
|
||||
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
|
||||
across all distributed processes (in-place).
|
||||
|
||||
This is a collective operation and MUST be invoked on every rank — typically just before
|
||||
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
|
||||
reported by the main process only reflect rank 0; for bottleneck-style timings
|
||||
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
|
||||
"""
|
||||
if self.accelerator is None or self.accelerator.num_processes <= 1:
|
||||
return
|
||||
|
||||
buckets: dict[str, list[str]] = defaultdict(list)
|
||||
for name, meter in self.metrics.items():
|
||||
if meter.reduction != "none":
|
||||
buckets[meter.reduction].append(name)
|
||||
if not buckets:
|
||||
return
|
||||
|
||||
device = self.accelerator.device
|
||||
for reduction, names in buckets.items():
|
||||
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
|
||||
reduced = self.accelerator.reduce(tensor, reduction=reduction)
|
||||
for name, value in zip(names, reduced.tolist(), strict=True):
|
||||
meter = self.metrics[name]
|
||||
# Preserve avg == sum / count so a later .update() on this meter accumulates
|
||||
# against the cluster view, not the stale per-rank history.
|
||||
meter.avg = value
|
||||
meter.sum = value * meter.count
|
||||
|
||||
def __str__(self) -> str:
|
||||
display_list = [
|
||||
f"step:{format_big_number(self.steps)}",
|
||||
|
||||
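The rescaling at the end of `reduce_across_ranks` (overwriting `avg`, then setting `sum = value * count`) is easier to follow in isolation. Below is a minimal sketch, assuming a simplified stand-in meter: `TinyMeter` and `apply_reduced` are hypothetical names for illustration, and the collective reduction is replaced by a plain `max` over two per-rank averages, so no `torch` or accelerator is involved.

```python
class TinyMeter:
    """Simplified stand-in for AverageMeter (hypothetical, illustration only)."""

    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def apply_reduced(self, value: float) -> None:
        # Overwrite the local average with the cross-rank value and rescale
        # ``sum`` so that avg == sum / count still holds afterwards.
        self.avg = value
        self.sum = value * self.count


# Two ranks with different per-rank timings (made-up numbers).
rank0, rank1 = TinyMeter(), TinyMeter()
rank0.update(0.10)
rank0.update(0.20)
rank1.update(0.50)
rank1.update(0.70)

# Stand-in for a "max" reduction across ranks.
slowest = max(rank0.avg, rank1.avg)
rank0.apply_reduced(slowest)

# A later update now accumulates against the reduced view,
# not against rank 0's own earlier history.
rank0.update(0.30)
```

Without the `sum` rescaling, the next `update` would recompute `avg` from the stale per-rank `sum` and silently discard the reduced value.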
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helpers shared across annotation-pipeline tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
|
||||
def make_canned_responder(
|
||||
responses_by_marker: dict[str, Any],
|
||||
default: Any = None,
|
||||
) -> StubVlmClient:
|
||||
"""Return a stub that picks a response by inspecting the user prompt.
|
||||
|
||||
For each call the responder examines the last user-message text and
|
||||
returns the response keyed by the first marker substring it contains.
|
||||
Falls back to ``default`` if no marker matches.
|
||||
"""
|
||||
|
||||
def responder(messages: list[dict[str, Any]]) -> Any:
|
||||
last_user_text = ""
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
last_user_text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
last_user_text = block.get("text", "")
|
||||
for marker, response in responses_by_marker.items():
|
||||
if marker in last_user_text:
|
||||
return response
|
||||
return default
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def encode_vqa_answer(payload: dict[str, Any]) -> str:
|
||||
return json.dumps(payload, sort_keys=True)
|
||||
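The marker-matching rule described in the `make_canned_responder` docstring can be tried without the VLM client. This is a minimal sketch that mirrors the inner `responder` as a standalone function; `pick_response` is a hypothetical name, and the prompts and canned responses are made up for illustration.

```python
from typing import Any


def pick_response(
    messages: list[dict[str, Any]],
    responses_by_marker: dict[str, Any],
    default: Any = None,
) -> Any:
    """Return the response keyed by the first marker found in the last user text."""
    last_user_text = ""
    for message in messages:
        if message.get("role") != "user":
            continue
        content = message.get("content")
        if isinstance(content, str):
            last_user_text = content
        elif isinstance(content, list):
            # Multimodal content: keep the text of the last text block.
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    last_user_text = block.get("text", "")
    # Dict order decides which marker wins when several match.
    for marker, response in responses_by_marker.items():
        if marker in last_user_text:
            return response
    return default


canned = {
    "atomic subtasks": {"subtasks": []},
    "compressed semantic memory": {"memory": "poured once"},
}
messages = [
    {"role": "system", "content": "You annotate robot episodes."},
    {"role": "user", "content": [{"type": "text", "text": "Update the compressed semantic memory."}]},
]
memory_reply = pick_response(messages, canned)
fallback_reply = pick_response([{"role": "user", "content": "unrelated"}], canned, default="n/a")
```

Because matching is by substring on the last user message only, a prompt whose wording changes can stop matching its marker and fall through to `default`.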
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
The on-disk dataset builder lives with the other dataset factories in
|
||||
``tests/fixtures/dataset_factories.py`` (:func:`build_annotation_dataset`);
|
||||
these fixtures only wire it into pytest.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``build_annotation_dataset`` pulls in ``lerobot.datasets`` (HF ``datasets``
|
||||
# + ``pandas``, only in the ``dataset`` extra), so it's imported lazily inside
|
||||
# each fixture — this conftest stays importable without that extra. The test
|
||||
# modules ``pytest.importorskip("datasets")`` so they skip rather than error.
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds",
|
||||
episode_specs=[
|
||||
(0, 12, "Could you tidy the kitchen please?"),
|
||||
(1, 12, "Please clean up the kitchen"),
|
||||
],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_episode_root(tmp_path: Path) -> Path:
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
return build_annotation_dataset(
|
||||
tmp_path / "ds_one",
|
||||
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
@@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the shared annotation fixture (:func:`build_annotation_dataset`),
|
||||
runs the full annotation pipeline against it with a stub VLM, and prints a
|
||||
short report. This is intentionally not a pytest test — it exercises the
|
||||
CLI plumbing — but it reuses the same on-disk dataset builder as the pytest
|
||||
fixtures so there is no duplicated fixture code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from tests.fixtures.dataset_factories import build_annotation_dataset
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
text = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
if "atomic subtasks" in text:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
}
|
||||
if "compressed semantic memory" in text:
|
||||
return {"memory": "poured once"}
|
||||
if "acknowledgement the robot" in text:
|
||||
return {"text": "Sure."}
|
||||
if "compact interjection" in text:
|
||||
return {"interjection": "use less water", "speech": "Using less water."}
|
||||
if "frame-grounded visual question" in text:
|
||||
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = build_annotation_dataset(
|
||||
Path(tmp) / "ds",
|
||||
episode_specs=[(0, 30, "Pour water into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
vlm = StubVlmClient(responder=_stub_responder)
|
||||
cfg = AnnotationPipelineConfig()
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan),
|
||||
interjections=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.interjections, seed=cfg.seed),
|
||||
vqa=GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
summary = executor.run(root)
|
||||
print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}")
|
||||
print(f"validation: {summary.validation_report.summary()}")
|
||||
print(f"shards rewritten: {len(summary.written_paths)}")
|
||||
|
||||
# Assert the interjection code path actually fired — otherwise a stale
|
||||
# canned-VLM marker would silently produce zero interjections and this
|
||||
# smoke run would still "pass" by only printing.
|
||||
import pyarrow.parquet as pq # noqa: PLC0415
|
||||
|
||||
events = [
|
||||
r
|
||||
for shard in summary.written_paths
|
||||
for ev in pq.read_table(shard).column("language_events").to_pylist()
|
||||
for r in ev
|
||||
]
|
||||
n_interjections = sum(1 for r in events if r.get("style") == "interjection")
|
||||
n_speech = sum(1 for r in events if r.get("style") is None and r.get("role") == "assistant")
|
||||
print(f"interjections={n_interjections} speech_atoms={n_speech}")
|
||||
assert n_interjections > 0, "no interjection rows produced — check the interjection prompt marker"
|
||||
assert n_speech > 0, "no speech tool-call atoms produced — check the speech prompt marker"
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
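The flattening and counting at the end of `main()` can be checked on in-memory data, without parquet shards. A minimal sketch, assuming each frame carries a list of event rows with `style` and `role` keys as in the smoke script; the `text` key and the `"vqa"` style value are assumptions made for illustration.

```python
# Hypothetical per-frame ``language_events`` column: one list of event rows per frame.
language_events = [
    [],
    [{"style": "interjection", "role": "user", "text": "use less water"}],
    [{"style": None, "role": "assistant", "text": "Using less water."}],
    [{"style": "vqa", "role": "assistant", "text": "How many cups?"}],
]

# Flatten frames -> rows, then count the two kinds the smoke run asserts on.
events = [row for frame_events in language_events for row in frame_events]
n_interjections = sum(1 for r in events if r.get("style") == "interjection")
n_speech = sum(1 for r in events if r.get("style") is None and r.get("role") == "assistant")
```

Rows with a non-`None` style, such as the assumed `"vqa"` row, are excluded from the speech count even when their role is `assistant`.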
@@ -1,246 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for :class:`VideoFrameProvider` method bindings.
|
||||
|
||||
These were prompted by a real regression: ``video_for_episode`` was once
|
||||
indented one level too deep so it ended up nested *inside* a module-level
|
||||
helper (after that function's ``return`` statement) — silently dead code
|
||||
that meant production runs with ``use_video_url=False`` would
|
||||
``AttributeError`` on ``self.frame_provider.video_for_episode(...)``. The
|
||||
existing module tests didn't catch it because they exercise stub providers.
|
||||
|
||||
The tests below assert on the class itself (not on an instance), so a
|
||||
future reindent regression flips them to red without needing a real
|
||||
LeRobot dataset on disk.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.frames import VideoFrameProvider # noqa: E402
|
||||
|
||||
|
||||
class _FakeMeta:
|
||||
"""Minimal metadata stub exposing ``video_keys`` / ``camera_keys``."""
|
||||
|
||||
def __init__(self, video_keys: list[str], image_keys: list[str], video_path: Path | None = None) -> None:
|
||||
self.video_keys = video_keys
|
||||
self.camera_keys = [*video_keys, *image_keys]
|
||||
self._video_path = video_path
|
||||
self.episodes = {0: {f"videos/{key}/from_timestamp": 0.0 for key in video_keys}}
|
||||
|
||||
def get_video_file_path(self, episode_index: int, camera_key: str) -> Path:
|
||||
return self._video_path
|
||||
|
||||
|
||||
def test_default_camera_key_skips_image_only_cameras(tmp_path: Path, monkeypatch) -> None:
|
||||
"""The default camera must be a *video* key — image-stored cameras have no
|
||||
``videos/<key>/from_timestamp`` and would KeyError in the clip/decode path.
|
||||
|
||||
Regression: a dataset whose first ``camera_keys`` entry was an image-stored
|
||||
camera (e.g. ``observation.images.wrist``) crashed at clip extraction.
|
||||
"""
|
||||
fake = _FakeMeta(
|
||||
video_keys=["observation.images.robot0_agentview_right"],
|
||||
image_keys=["observation.images.wrist"],
|
||||
)
|
||||
import lerobot.datasets.dataset_metadata as meta_mod
|
||||
|
||||
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
|
||||
provider = VideoFrameProvider(root=tmp_path)
|
||||
assert provider.camera_key == "observation.images.robot0_agentview_right"
|
||||
assert "observation.images.wrist" not in provider.camera_keys
|
||||
|
||||
|
||||
def test_video_for_episode_is_a_method_of_videoframeprovider():
    """``video_for_episode`` must be a bound method, not nested dead code."""
    assert callable(getattr(VideoFrameProvider, "video_for_episode", None))


def test_episode_clip_path_is_a_method_of_videoframeprovider():
    """``episode_clip_path`` is now a method (was a free function reaching
    into ``provider._meta`` from outside the class)."""
    assert callable(getattr(VideoFrameProvider, "episode_clip_path", None))


def test_videoframeprovider_has_a_lock_for_concurrent_use():
    """A ``ThreadPoolExecutor`` runs the plan / interjections / vqa phases
    concurrently; the cache + warn-flag accesses must be guarded.
    """
    import threading

    # Fresh-instance check via a minimal fake to avoid touching the hub.
    # The lock is declared with ``init=False`` and has a default factory,
    # so a constructed instance must own a real ``threading.Lock``.
    lock_field = next(
        (f for f in VideoFrameProvider.__dataclass_fields__.values() if f.name == "_lock"),
        None,
    )
    assert lock_field is not None
    assert lock_field.default_factory is threading.Lock


@pytest.fixture
def sample_video(tmp_path: Path) -> Path:
    """A 3 s 10 fps test-pattern mp4, written with ffmpeg."""
    if shutil.which("ffmpeg") is None:
        pytest.skip("ffmpeg not available")
    out = tmp_path / "sample.mp4"
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-f",
            "lavfi",
            "-i",
            "testsrc=duration=3:size=160x120:rate=10",
            "-pix_fmt",
            "yuv420p",
            str(out),
        ],
        check=True,
        capture_output=True,
    )
    return out


def _provider_for_video(tmp_path: Path, video: Path, monkeypatch) -> VideoFrameProvider:
    """A provider whose single camera resolves to ``video`` via fake metadata."""
    fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=video)
    import lerobot.datasets.dataset_metadata as meta_mod

    monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
    return VideoFrameProvider(root=tmp_path, tolerance_s=0.2)


def test_decode_returns_one_uint8_frame_per_timestamp(
    sample_video: Path, tmp_path: Path, monkeypatch
) -> None:
    """``_decode`` routes through ``decode_video_frames`` (torchcodec when
    available, PyAV otherwise) — no subprocess fallback.
    """
    provider = _provider_for_video(tmp_path, sample_video, monkeypatch)
    timestamps = [0.0, 1.0, 2.5]
    frames = provider._decode(0, timestamps, "observation.images.cam")

    assert len(frames) == len(timestamps)
    for frame in frames:
        assert isinstance(frame, torch.Tensor)
        assert frame.dtype == torch.uint8
        assert frame.shape == (3, 120, 160)


def test_frames_at_snaps_mid_frame_grid_to_real_frames(
    sample_video: Path, tmp_path: Path, monkeypatch
) -> None:
    """Uniform sampling grids land mid-frame; ``frames_at`` must snap them to
    real frame timestamps before decoding.

    Regression: ``decode_video_frames`` rejects queries farther than
    ``tolerance_s`` (default 10 ms) from a decodable frame, so un-snapped
    mid-frame queries raised ``FrameTimestampError`` wholesale and the plan
    module silently lost its contact sheets for most episodes.
    """
    from types import SimpleNamespace

    fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=sample_video)
    import lerobot.datasets.dataset_metadata as meta_mod

    monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
    provider = VideoFrameProvider(root=tmp_path)  # default 10 ms tolerance
    # 10 fps fixture -> frames at 0.0, 0.1, ...; queries sit mid-frame.
    record = SimpleNamespace(episode_index=0, frame_timestamps=[i / 10 for i in range(30)])

    frames = provider.frames_at(record, [0.149, 1.234, 2.04], camera_key="observation.images.cam")

    assert len(frames) == 3
    for frame in frames:
        assert isinstance(frame, torch.Tensor)
        assert frame.shape == (3, 120, 160)


def test_decode_returns_empty_list_on_missing_file(tmp_path: Path, monkeypatch) -> None:
    """A missing video is a recoverable no-frames condition, never a crash."""
    provider = _provider_for_video(tmp_path, tmp_path / "does_not_exist.mp4", monkeypatch)
    assert provider._decode(0, [0.0], "observation.images.cam") == []


def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch) -> None:
    """Clip extraction delegates to ``video_utils.reencode_video`` with the
    episode's ``[from_timestamp, to_timestamp)`` trim window — no subprocess.
    """
    from types import SimpleNamespace

    import lerobot.annotations.steerable_pipeline.frames as frames_mod

    src = tmp_path / "src.mp4"
    src.write_bytes(b"src")
    fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=src)
    fake.episodes[0]["videos/observation.images.cam/from_timestamp"] = 1.5
    fake.episodes[0]["videos/observation.images.cam/to_timestamp"] = 4.0
    import lerobot.datasets.dataset_metadata as meta_mod

    monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)

    captured = {}

    def fake_reencode(
        input_video_path,
        output_video_path,
        camera_encoder=None,
        overwrite=False,
        start_time_s=None,
        end_time_s=None,
    ):
        captured.update(
            src=Path(input_video_path),
            encoder=camera_encoder,
            start_time_s=start_time_s,
            end_time_s=end_time_s,
        )
        Path(output_video_path).write_bytes(b"clip")

    monkeypatch.setattr(frames_mod, "reencode_video", fake_reencode, raising=True)
    provider = VideoFrameProvider(root=tmp_path)
    record = SimpleNamespace(episode_index=0, frame_timestamps=[0.0, 1.0])

    out = provider.episode_clip_path(record, tmp_path / "clips")

    assert out == tmp_path / "clips" / "ep_000000.mp4"
    assert captured["src"] == src
    assert captured["start_time_s"] == 1.5
    assert captured["end_time_s"] == 4.0
    # H.264 so the clip is decodable by vllm's libav build (sources are often AV1).
    assert captured["encoder"].vcodec == "h264"


def test_videoframeprovider_serializes_decodes_with_a_lock() -> None:
    """torchcodec's cached per-file decoder is single-threaded; the provider
    must own a dedicated lock that ``_decode`` holds around the decoder call.
    """
    import threading

    lock_field = VideoFrameProvider.__dataclass_fields__.get("_decode_lock")
    assert lock_field is not None
    assert lock_field.default_factory is threading.Lock
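The snapping behaviour that `test_frames_at_snaps_mid_frame_grid_to_real_frames` guards can be pictured with a small standalone sketch. This is not the code in `VideoFrameProvider.frames_at`; `snap_to_frames` is a hypothetical helper, and it assumes the frame timestamps are already sorted in ascending order.

```python
from bisect import bisect_left


def snap_to_frames(queries: list[float], frame_timestamps: list[float]) -> list[float]:
    """Map each query time to the nearest entry of a sorted ``frame_timestamps``.

    Hypothetical helper for illustration only, not the library's implementation.
    """
    if not frame_timestamps:
        return []
    snapped = []
    for q in queries:
        i = bisect_left(frame_timestamps, q)
        # The nearest frame is one of the neighbours around the insertion point.
        candidates = frame_timestamps[max(i - 1, 0) : i + 1]
        snapped.append(min(candidates, key=lambda t: abs(t - q)))
    return snapped


# 10 fps grid, mirroring the fixture used by the test above.
frames = [i / 10 for i in range(30)]
snapped = snap_to_frames([0.149, 1.234, 2.04], frames)
```

Because the helper returns elements of `frame_timestamps` themselves, a decoder with a tight tolerance would receive timestamps that match real frames exactly rather than values a few milliseconds off.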
@@ -1,390 +0,0 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 1/2/3 unit tests with stubbed VLMs."""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import PIL.Image
import pytest

# ``lerobot.annotations`` imports pull in ``lerobot.datasets`` (-> the HF
# ``datasets`` library), which only ships under the ``dataset`` extra. Skip
# this module in tiers without it instead of erroring at import.
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")

from lerobot.annotations.steerable_pipeline.config import (  # noqa: E402
    InterjectionsConfig,
    PlanConfig,
    VqaConfig,
)
from lerobot.annotations.steerable_pipeline.modules import (  # noqa: E402
    GeneralVqaModule,
    InterjectionsAndSpeechModule,
    PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.reader import iter_episodes  # noqa: E402
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging  # noqa: E402
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient  # noqa: E402

from ._helpers import make_canned_responder  # noqa: E402


@dataclass
class _StubFrameProvider:
    """Returns one sentinel object per requested timestamp."""

    # A real (tiny) PIL image so the contact-sheet builder, which resizes and
    # tiles frames, has something to draw. VQA still passes it through by
    # identity via ``to_image_blocks``.
    sentinel: Any = field(default_factory=lambda: PIL.Image.new("RGB", (32, 24)))
    cameras: tuple[str, ...] = ("observation.images.top",)
    calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
    video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)

    @property
    def camera_keys(self) -> list[str]:
        return list(self.cameras)

    def frames_at(self, record, timestamps, camera_key=None):
        self.calls.append((record.episode_index, tuple(timestamps), camera_key))
        return [self.sentinel] * len(timestamps)

    def video_for_episode(self, record, max_frames, camera_key=None):
        self.video_calls.append((record.episode_index, max_frames, camera_key))
        n = min(max_frames, len(record.frame_timestamps))
        return [self.sentinel] * n


def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
    def responder(messages):
        captured.append(list(messages))
        return reply

    return StubVlmClient(responder=responder)


def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
    vlm = make_canned_responder(
        {
            "atomic subtasks": {
                "subtasks": [
                    {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
                    {"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
                    {"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
                ]
            },
            "compressed semantic memory": {"memory": "wiped the counter once"},
        },
    )
    module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig())
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("plan")

    styles = {r["style"] for r in rows}
    assert {"subtask", "plan", "memory"}.issubset(styles)
    # subtask timestamps must be exact frame timestamps
    frame_set = set(record.frame_timestamps)
    for row in rows:
        assert row["timestamp"] in frame_set
    # one plan row per subtask boundary; the first lands at t0 and each
    # plan is the deterministic numbered list of still-todo subtasks
    plan_rows = sorted((r for r in rows if r["style"] == "plan"), key=lambda r: r["timestamp"])
    subtask_rows = [r for r in rows if r["style"] == "subtask"]
    assert len(plan_rows) == len(subtask_rows)
    assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
    # the t0 plan enumerates all subtasks; later plans shrink
    assert plan_rows[0]["content"].startswith("1. ")
    assert len(plan_rows[0]["content"].splitlines()) == len(subtask_rows)
    assert len(plan_rows[-1]["content"].splitlines()) == 1


def test_module1_emit_memory_false_skips_memory_keeps_subtasks_and_plan(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """``emit_memory=False`` drops ``memory`` rows (and their VLM calls) while
    leaving subtask + plan generation intact — symmetric to ``emit_plan``."""
    vlm = make_canned_responder(
        {
            "atomic subtasks": {
                "subtasks": [
                    {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
                    {"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
                    {"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
                ]
            },
            "compressed semantic memory": {"memory": "wiped the counter once"},
        },
    )
    module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig(emit_memory=False))
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("plan")

    styles = {r["style"] for r in rows}
    assert "memory" not in styles
    assert {"subtask", "plan"}.issubset(styles)


def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
    vlm = make_canned_responder(
        {"acknowledgement the robot": {"text": "Sure, on it."}},
    )
    module = InterjectionsAndSpeechModule(
        vlm=vlm,
        config=InterjectionsConfig(max_interjections_per_episode=0),
    )
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("interjections")
    assert len(rows) == 1
    only = rows[0]
    assert only["role"] == "assistant"
    assert only["style"] is None
    assert only["content"] is None
    assert only["timestamp"] == record.frame_timestamps[0]
    assert only["tool_calls"][0]["function"]["name"] == "say"


def test_module2_mid_episode_emits_paired_interjection_and_speech(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """Module 2 anchors interjections on Module 1's subtask boundaries.

    The executor runs Module 1 first, then Module 2 reads the subtask
    rows back from the same staging tree (see
    ``_mid_episode_interjections``). Reproduce that contract here by
    seeding the staging with two subtask rows so a single ``0 → 1``
    boundary exists for Module 2 to anchor on.
    """
    vlm = make_canned_responder(
        {
            "acknowledgement the robot": {"text": "OK."},
            # Marker matches the distinctive line of
            # ``interjections_interjection.txt`` ("Write ONE compact
            # interjection ..."). Keep this in sync with that prompt's
            # wording — the canned responder matches on substring.
            "Write ONE compact interjection": {
                "interjection": "now wipe the counter please",
                "speech": "On it.",
            },
        },
    )
    module = InterjectionsAndSpeechModule(
        vlm=vlm,
        config=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.2),
        seed=7,
    )
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    # Seed Module 1's subtask staging so Module 2 has a boundary to
    # anchor on (it bails with zero rows when no spans exist — the
    # production executor guarantees Module 1 ran first).
    boundary_ts = float(record.frame_timestamps[len(record.frame_timestamps) // 2])
    staging.write(
        "plan",
        [
            {
                "role": "assistant",
                "content": "grasp the sponge",
                "style": "subtask",
                "timestamp": float(record.frame_timestamps[0]),
                "tool_calls": None,
            },
            {
                "role": "assistant",
                "content": "wipe the counter",
                "style": "subtask",
                "timestamp": boundary_ts,
                "tool_calls": None,
            },
        ],
    )
    module.run_episode(record, staging)
    rows = staging.read("interjections")

    interjections = [r for r in rows if r["style"] == "interjection"]
    speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
    assert len(interjections) == 1
    assert len(speeches) >= 2  # initial t=0 + one paired with the interjection
    inter_t = interjections[0]["timestamp"]
    assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)


def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
    payload = {
        "question": "How many cups?",
        "answer": {"label": "cup", "count": 2, "note": "white & blue"},
    }
    vlm = make_canned_responder({"frame-grounded visual question": payload})
    module = GeneralVqaModule(
        vlm=vlm,
        config=VqaConfig(vqa_emission_hz=1.0, K=3),
        seed=1,
        frame_provider=_StubFrameProvider(cameras=("observation.images.top", "observation.images.wrist")),
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("vqa")
    # every vqa row must carry a camera tag and one of the configured cameras
    for r in rows:
        assert r["style"] == "vqa"
        assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
    # at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
    user_keys = [(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"]
    assistant_keys = [
        (r["timestamp"], r["camera"]) for r in rows if r["role"] == "assistant" and r["style"] == "vqa"
    ]
    assert len(user_keys) == len(set(user_keys))
    assert len(assistant_keys) == len(set(assistant_keys))
    # both cameras must be represented
    assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
    # every emitted timestamp must be an exact source frame timestamp
    frame_set = set(record.frame_timestamps)
    for ts, _ in user_keys + assistant_keys:
        assert ts in frame_set


def test_module1_attaches_contact_sheets_to_subtask_prompt(
    fixture_dataset_root: Path, tmp_path: Path
) -> None:
    """Module 1 sends timestamped contact-sheet image blocks (not a raw video block)."""
    captured: list[list[dict[str, Any]]] = []
    payload = {
        "subtasks": [
            {"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.5},
            {"text": "wipe the counter", "start": 0.5, "end": 1.1},
        ]
    }
    memory_payload = {"memory": "wiped once"}

    def responder(messages):
        captured.append(list(messages))
        text = ""
        for m in messages:
            for block in m.get("content", []):
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text", "")
        if "compressed semantic memory" in text:
            return memory_payload
        return payload

    provider = _StubFrameProvider()
    module = PlanSubtasksMemoryModule(
        vlm=StubVlmClient(responder=responder),
        # Disable the rephrasings sub-prompt so the test's only video-bearing
        # call is the subtask one — keeps the assertions below focused on
        # ``_generate_subtasks`` rather than fighting the order of unrelated
        # text-only Module-1 sub-prompts.
        config=PlanConfig(frames_per_second=2.0, max_frames_per_prompt=60, n_task_rephrasings=0),
        frame_provider=provider,
    )
    record = next(iter_episodes(fixture_dataset_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    # Find the call carrying the subtask prompt rather than blindly taking
    # captured[0] — Module 1 issues several sub-prompts and their order is
    # not part of the contract.
    assert captured, "no VLM calls made"

    def _prompt_text(messages):
        for m in messages:
            for block in m.get("content", []):
                if isinstance(block, dict) and block.get("type") == "text":
                    return block.get("text", "")
        return ""

    subtask_calls = [m for m in captured if "atomic subtasks" in _prompt_text(m)]
    assert len(subtask_calls) == 1, "expected exactly one subtask-prompt VLM call"
    content = subtask_calls[0][0]["content"]
    video_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "video"]
    image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
    text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
    assert video_blocks == [], "contact-sheet mode must not emit a raw video block"
    assert len(image_blocks) >= 1, f"expected >=1 contact-sheet image block, got {content}"
    assert all(isinstance(b["image"], PIL.Image.Image) for b in image_blocks)
    assert len(text_blocks) == 1
    # the prompt is prefixed with the contact-sheet reading instructions
    assert text_blocks[0]["text"].startswith("CONTACT SHEETS")
    # frames were decoded for this episode at episode-relative timestamps
    assert provider.calls and provider.calls[0][0] == record.episode_index


def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
    """Each VQA prompt must carry a single image block at the emission frame."""
    captured: list[list[dict[str, Any]]] = []
    payload = {
        "question": "How many cups?",
        "answer": {"label": "cup", "count": 1},
    }
    provider = _StubFrameProvider()
    module = GeneralVqaModule(
        vlm=_spy_responder(captured, payload),
        config=VqaConfig(vqa_emission_hz=1.0, K=1),
        seed=0,
        frame_provider=provider,
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)

    assert captured, "no VLM calls made"
    for messages in captured:
        content = messages[0]["content"]
        image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
        text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
        assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
        assert image_blocks[0]["image"] is provider.sentinel
        assert len(text_blocks) == 1
    # provider was called once per emission per camera with the exact emission timestamp
    for ep_idx, ts_tuple, camera in provider.calls:
        assert ep_idx == record.episode_index
        assert len(ts_tuple) == 1
        assert ts_tuple[0] in record.frame_timestamps
        assert camera in provider.cameras


def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
    payload = {
        "question": "Where is the cup?",
        "answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
    }
    vlm = make_canned_responder({"frame-grounded visual question": payload})
    module = GeneralVqaModule(
        vlm=vlm,
        config=VqaConfig(vqa_emission_hz=1.0, K=2),
        seed=2,
        frame_provider=_StubFrameProvider(),
    )
    record = next(iter_episodes(single_episode_root))
    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
    module.run_episode(record, staging)
    rows = staging.read("vqa")
    for row in rows:
        if row["role"] == "assistant" and row["style"] == "vqa":
            decoded = json.loads(row["content"])
            assert "detections" in decoded
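The plan-row assertions in `test_module1_plan_memory_subtask_smoke` (one plan per subtask boundary, the first starting with `1. `, the last one a single line) are consistent with a shrinking numbered list. A minimal sketch of that shape follows; `plans_from_subtasks` is a hypothetical name, and renumbering every plan from 1 is an assumption that the tests do not confirm for `PlanSubtasksMemoryModule`.

```python
def plans_from_subtasks(subtasks: list[str]) -> list[str]:
    """Build one plan per subtask boundary as a numbered list of remaining subtasks.

    Hypothetical illustration; the real module may format or number plans differently.
    """
    plans = []
    for start in range(len(subtasks)):
        # At boundary ``start`` the current subtask and everything after it is still to do.
        remaining = subtasks[start:]
        plans.append("\n".join(f"{n}. {text}" for n, text in enumerate(remaining, start=1)))
    return plans


plans = plans_from_subtasks(
    [
        "grasp the handle of the sponge",
        "wipe the counter from left to right",
        "place the sponge into the sink",
    ]
)
```

Deriving plans from the subtask list rather than asking the VLM again would keep them deterministic, which matches the test comment describing each plan as "the deterministic numbered list of still-todo subtasks".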
@@ -1,183 +0,0 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end smoke: pipeline output → canonical recipe rendering."""

from __future__ import annotations

from pathlib import Path

import pytest

# ``pyarrow`` and the ``lerobot.datasets`` chain (-> the HF ``datasets``
# library) only ship under the ``dataset`` extra. Skip this module in
# tiers without it instead of erroring at import.
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")

import pyarrow.parquet as pq  # noqa: E402

from lerobot.annotations.steerable_pipeline.config import (  # noqa: E402
    AnnotationPipelineConfig,
    InterjectionsConfig,
    PlanConfig,
    VqaConfig,
)
from lerobot.annotations.steerable_pipeline.executor import Executor  # noqa: E402
from lerobot.annotations.steerable_pipeline.modules import (  # noqa: E402
    GeneralVqaModule,
    InterjectionsAndSpeechModule,
    PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.validator import StagingValidator  # noqa: E402
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter  # noqa: E402
from lerobot.configs.recipe import MessageTurn, TrainingRecipe  # noqa: E402
from lerobot.datasets.language_render import render_sample  # noqa: E402

from ._helpers import make_canned_responder  # noqa: E402


def _build_style_blend_recipe() -> TrainingRecipe:
    """Inline blend recipe that consumes every style this pipeline produces.

    The language schema/DSL work used to ship
    ``src/lerobot/configs/recipes/pi05_hirobot.yaml`` as a canonical
    example, but that file was dropped during review. The contract this
    test guards is "the recipe DSL can render non-empty messages from
    pipeline output", which doesn't require a specific YAML — so we build
    the equivalent blend in code.
    """
    return TrainingRecipe(
        blend={
            "low_level_execution": TrainingRecipe(
                weight=0.35,
                messages=[
                    MessageTurn(
                        role="user",
                        content="${task}\nPlan: ${plan}\nMemory: ${memory}",
                        stream="high_level",
                    ),
                    MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
                ],
            ),
            "user_interjection_response": TrainingRecipe(
                weight=0.16,
                bindings={
                    "speech": "emitted_at(t, role=assistant, tool_name=say)",
                    "interjection": "emitted_at(t, style=interjection)",
                },
                messages=[
                    MessageTurn(role="user", content="${task}", stream="high_level"),
                    MessageTurn(
                        role="user",
                        content="${interjection}",
                        stream="high_level",
                        if_present="interjection",
                    ),
                    MessageTurn(
                        role="assistant",
                        content="${plan}",
                        stream="high_level",
                        target=True,
                        if_present="plan",
                        tool_calls_from="speech",
                    ),
                ],
            ),
        }
    )


def _build_executor() -> Executor:
    vlm = make_canned_responder(
        {
            "atomic subtasks": {
                "subtasks": [
                    {"text": "grasp the bottle", "start": 0.0, "end": 0.5},
                    {"text": "pour into the cup", "start": 0.5, "end": 1.0},
                    {"text": "place the bottle down", "start": 1.0, "end": 1.5},
                ]
            },
            "compressed semantic memory": {"memory": "poured once"},
            "acknowledgement the robot": {"text": "Sure."},
            "compact interjection": {
                "interjection": "use less water",
                "speech": "Using less water.",
            },
            "frame-grounded visual question": {
                "question": "How many cups?",
                "answer": {"label": "cup", "count": 1},
            },
        },
    )
    config = AnnotationPipelineConfig(
        plan=PlanConfig(),
        interjections=InterjectionsConfig(max_interjections_per_episode=1, interjection_min_t=0.5),
        vqa=VqaConfig(vqa_emission_hz=1.0, K=2),
    )
    return Executor(
        config=config,
        plan=PlanSubtasksMemoryModule(vlm=vlm, config=config.plan),
        interjections=InterjectionsAndSpeechModule(vlm=vlm, config=config.interjections, seed=config.seed),
        vqa=GeneralVqaModule(vlm=vlm, config=config.vqa, seed=config.seed),
        writer=LanguageColumnsWriter(),
        validator=StagingValidator(),
    )


def test_canonical_recipe_renders_nonempty_from_pipeline_output(
    single_episode_root: Path,
) -> None:
    executor = _build_executor()
    summary = executor.run(single_episode_root)
    # validator may emit warnings but no errors for the synthetic fixture
    assert summary.validation_report.ok, summary.validation_report.summary()

    table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
    persistent_lists = table.column("language_persistent").to_pylist()
    events_lists = table.column("language_events").to_pylist()
    timestamps = table.column("timestamp").to_pylist()

    recipe = _build_style_blend_recipe()

    rendered_any = False
    for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
        result = render_sample(
            recipe=recipe,
            persistent=persistent,
            events=events,
            t=float(ts),
            sample_idx=0,
            dataset_ctx={"task": "Pour water from the bottle into the cup."},
        )
        if result is None:
            continue
        if result["messages"]:
            rendered_any = True
            assert result["target_message_indices"]
            break
    assert rendered_any, "recipe rendered no messages from pipeline output"

    # Sanity: speech atom appears in events column intact
    flat_events = [r for ev in events_lists for r in ev]
    speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
    assert speech_rows
    say = speech_rows[0]["tool_calls"][0]
    assert say["function"]["name"] == "say"
    assert isinstance(say["function"]["arguments"]["text"], str)
    # The pipeline does not write a ``tools`` column — the say schema lives
    # as a constant (``SAY_TOOL_SCHEMA``) so the language row struct is the
    # single source of truth for the v3.1 schema.
    assert "tools" not in table.column_names
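The speech rows these tests inspect share a recognisable shape: an assistant row with no `style`, no `content`, and a single `say` tool call whose `arguments` carry the spoken text. The sketch below rebuilds only the fields the assertions touch. `make_speech_row` is a hypothetical stand-in, not the `speech_atom` helper from `lerobot.annotations.steerable_pipeline.writer`, which may emit additional keys.

```python
from typing import Any


def make_speech_row(timestamp: float, text: str) -> dict[str, Any]:
    """Assemble a speech row with only the fields the tests above assert on.

    Hypothetical illustration; the real ``speech_atom`` may include more keys.
    """
    return {
        "role": "assistant",
        "content": None,  # speech is carried by the tool call, not by content
        "style": None,  # speech rows have no style tag
        "timestamp": float(timestamp),
        "tool_calls": [{"function": {"name": "say", "arguments": {"text": text}}}],
    }


row = make_speech_row(0.0, "Sure, on it.")
say = row["tool_calls"][0]
```

Keeping the utterance inside `tool_calls` means a filter such as `style is None and role == "assistant"`, as used in the tests, is enough to separate speech rows from subtask, plan, memory, interjection, and vqa rows.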
@@ -1,133 +0,0 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validator behavior tests."""

from __future__ import annotations

import json
from pathlib import Path

import pytest

# ``lerobot.annotations`` imports pull in ``lerobot.datasets`` (-> the HF
# ``datasets`` library), which only ships under the ``dataset`` extra. Skip
# this module in tiers without it instead of erroring at import.
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")

from lerobot.annotations.steerable_pipeline.reader import iter_episodes  # noqa: E402
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging  # noqa: E402
from lerobot.annotations.steerable_pipeline.validator import StagingValidator  # noqa: E402
from lerobot.annotations.steerable_pipeline.writer import speech_atom  # noqa: E402


def _validate(root: Path, staging_dir: Path):
    records = list(iter_episodes(root))
    return StagingValidator().validate(records, staging_dir)


def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
    staging_dir = tmp_path / "stage"
    EpisodeStaging(staging_dir, 0).write(
        "vqa",
        [
            {
                "role": "assistant",
                "content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
                "style": "vqa",
                "timestamp": 9.999,  # not on any 10 fps frame
                "tool_calls": None,
            }
        ],
    )
    report = _validate(fixture_dataset_root, staging_dir)
    assert not report.ok
    assert any("does not match any source frame timestamp" in e for e in report.errors)
|
||||
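The misaligned-timestamp test relies on the fixture's frame grid, where frame `i` sits at `round(i / fps, 6)`. A minimal sketch of the kind of check the validator presumably performs is shown below. `matches_a_frame` and its `tol` parameter are hypothetical names introduced for illustration; the tolerance actually used by `StagingValidator` is not shown in this diff.

```python
def matches_a_frame(timestamp: float, fps: int, num_frames: int, tol: float = 1e-6) -> bool:
    """Hypothetical helper: True if ``timestamp`` lands on one of the frame timestamps.

    Frame timestamps follow the fixture's convention, ``round(i / fps, 6)``.
    The tolerance is an assumption, not a value taken from the validator.
    """
    return any(abs(timestamp - round(i / fps, 6)) <= tol for i in range(num_frames))


# At 10 fps the grid is 0.0, 0.1, 0.2, ... so 0.3 is on a frame and 9.999 is not.
on_grid = matches_a_frame(0.3, fps=10, num_frames=12)
off_grid = matches_a_frame(9.999, fps=10, num_frames=200)
```

Under these assumptions, a staged row at `9.999` is rejected even when the episode is long enough to contain a frame at `10.0`, because the two differ by far more than the tolerance.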
|
||||
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
# interjection at 0.3s with NO paired speech
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip it",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("paired speech" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"plan",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do x",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do x",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"interjections",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
speech_atom(0.4, "Replanning."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "replan",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
# missing co-timestamped plan refresh at 0.4s → error
|
||||
assert not report.ok
|
||||
assert any("co-timestamped plan update" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
    staging_dir = tmp_path / "stage"
    EpisodeStaging(staging_dir, 0).write(
        "plan",
        [
            {"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
        ],
    )
    report = _validate(fixture_dataset_root, staging_dir)
    assert not report.ok
    assert any("plan emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``vlm_client`` helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import _bind_serve_port # noqa: E402
|
||||
|
||||
|
||||
def test_bind_serve_port_substitutes_placeholder() -> None:
    # The {port} placeholder is replaced everywhere it appears, regardless of
    # parallel vs single server; the bug was the single-server path passing
    # it through unsubstituted.
    cmd = "vllm serve M --max-model-len 32768 --port {port}"
    assert _bind_serve_port(cmd, 8000) == "vllm serve M --max-model-len 32768 --port 8000"


def test_bind_serve_port_appends_when_missing() -> None:
    assert _bind_serve_port("vllm serve M", 8001) == "vllm serve M --port 8001"


def test_bind_serve_port_leaves_explicit_port_untouched() -> None:
    cmd = "vllm serve M --port 9000"
    assert _bind_serve_port(cmd, 8000) == cmd
|
||||
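The three tests above pin down the observable behavior of `_bind_serve_port` without showing its body. A minimal sketch consistent with those assertions might look like the following. `bind_serve_port_sketch` is a hypothetical name, and recognizing the `--port=N` spelling is an assumption that the tests do not cover.

```python
def bind_serve_port_sketch(cmd: str, port: int) -> str:
    """Hypothetical stand-in inferred from the tests; not the lerobot implementation."""
    # Case 1: fill in every "{port}" placeholder.
    if "{port}" in cmd:
        return cmd.replace("{port}", str(port))
    # Case 2: an explicit port is already present, so leave the command alone.
    # (Treating "--port=N" as explicit is an assumption.)
    tokens = cmd.split()
    if any(tok == "--port" or tok.startswith("--port=") for tok in tokens):
        return cmd
    # Case 3: no port anywhere, so append one.
    return f"{cmd} --port {port}"
```

Checking the placeholder first matters: a command such as `--port {port}` also contains the `--port` token, so testing for an explicit port before substituting would pass the placeholder through untouched, which is the bug the first test describes.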
@@ -1,357 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Writer correctness tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# ``pyarrow`` and the ``lerobot.annotations`` -> ``lerobot.datasets`` chain
|
||||
# (-> the HF ``datasets`` library) only ship under the ``dataset`` extra.
|
||||
# Skip this module in tiers without it instead of erroring at import.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging # noqa: E402
|
||||
from lerobot.annotations.steerable_pipeline.writer import ( # noqa: E402
|
||||
LanguageColumnsWriter,
|
||||
speech_atom,
|
||||
)
|
||||
|
||||
|
||||
def _stage_episode(
|
||||
staging_dir: Path,
|
||||
episode_index: int,
|
||||
*,
|
||||
plan: list[dict] | None = None,
|
||||
interjections: list[dict] | None = None,
|
||||
vqa: list[dict] | None = None,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, episode_index)
|
||||
if plan is not None:
|
||||
staging.write("plan", plan)
|
||||
if interjections is not None:
|
||||
staging.write("interjections", interjections)
|
||||
if vqa is not None:
|
||||
staging.write("vqa", vqa)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Every frame in an episode has a byte-identical persistent list."""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. wipe\n2. dry",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wiped the counter",
|
||||
"style": "memory",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent = table.column("language_persistent").to_pylist()
|
||||
first = persistent[0]
|
||||
assert first # non-empty
|
||||
for row in persistent:
|
||||
assert row == first, "persistent slice must be byte-identical across all frames"
|
||||
|
||||
|
||||
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
interjections=[
|
||||
speech_atom(0.0, "Got it."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip the dishes",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.5, "Skipping the dishes."),
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
events = table.column("language_events").to_pylist()
|
||||
for ts, ev in zip(timestamps, events, strict=True):
|
||||
if abs(ts - 0.0) < 1e-9:
|
||||
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
|
||||
elif abs(ts - 0.5) < 1e-9:
|
||||
assert any(r.get("style") == "interjection" for r in ev), ev
|
||||
assert any(r.get("style") is None for r in ev), ev
|
||||
else:
|
||||
assert ev == []
|
||||
|
||||
|
||||
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do X",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "did X",
|
||||
"style": "memory",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
interjections=[
|
||||
speech_atom(0.0, "OK"),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "wait",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.2,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.2, "Waiting"),
|
||||
],
|
||||
vqa=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "where is the cup?",
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(
|
||||
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
|
||||
sort_keys=True,
|
||||
),
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"camera": "observation.images.front",
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
|
||||
persistent = table.column("language_persistent").to_pylist()[0]
|
||||
persistent_styles = {r["style"] for r in persistent}
|
||||
assert persistent_styles == {"subtask", "plan", "memory"}
|
||||
|
||||
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
|
||||
event_styles = {r.get("style") for r in all_events}
|
||||
assert event_styles == {None, "interjection", "vqa"}
|
||||
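The routing test asserts that `subtask`, `plan` and `memory` rows end up in `language_persistent`, while speech atoms (style `None`), `interjection` and `vqa` rows end up in `language_events`. A minimal sketch of that style-based routing follows. `route_row` and the two set constants are hypothetical names; only the style values and column names come from the tests.

```python
# Style sets taken from the assertions in test_writer_column_routing.
PERSISTENT_STYLES = {"subtask", "plan", "memory"}
EVENT_STYLES = {None, "interjection", "vqa"}


def route_row(row: dict) -> str:
    """Hypothetical helper: return the parquet column a staged row belongs to."""
    style = row.get("style")
    if style in PERSISTENT_STYLES:
        return "language_persistent"
    if style in EVENT_STYLES:
        return "language_events"
    # Unknown styles are rejected here; how the real writer reacts is not shown.
    raise ValueError(f"unknown style: {style!r}")


routed = [route_row({"style": s}) for s in ("plan", None, "vqa")]
```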
|
||||
|
||||
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
writer = LanguageColumnsWriter()
|
||||
writer.write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
|
||||
table_a = pq.read_table(path)
|
||||
assert "subtask_index" not in table_a.column_names
|
||||
assert "language_persistent" in table_a.column_names
|
||||
assert "language_events" in table_a.column_names
|
||||
# The writer no longer emits a dataset-level ``tools`` column; the
|
||||
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
|
||||
# so the parquet stays small and the pipeline doesn't extend the schema.
|
||||
assert "tools" not in table_a.column_names
|
||||
|
||||
# second pass — must produce identical bytes for the language columns
|
||||
records_again = list(iter_episodes(fixture_dataset_root))
|
||||
writer.write_all(records_again, staging_dir, fixture_dataset_root)
|
||||
table_b = pq.read_table(path)
|
||||
assert (
|
||||
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
|
||||
)
|
||||
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
    """``_normalize_persistent_row`` must reject any non-persistent style."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row

    with pytest.raises(ValueError, match="non-persistent style"):
        _normalize_persistent_row(
            {"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
        )


def test_writer_normalize_rejects_misrouted_event_style() -> None:
    """``_normalize_event_row`` must reject any persistent style."""
    from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row

    with pytest.raises(ValueError):
        _normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
|
||||
|
||||
|
||||
def test_say_tool_schema_constant_is_well_formed() -> None:
    """``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
    ``tools`` column; chat-template consumers import them directly.
    """
    from lerobot.annotations.steerable_pipeline.writer import (
        DEFAULT_TOOLS,
        SAY_TOOL_SCHEMA,
    )

    assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
    assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
    params = SAY_TOOL_SCHEMA["function"]["parameters"]
    assert params["properties"]["text"]["type"] == "string"
    assert params["required"] == ["text"]
|
||||
|
||||
|
||||
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
    """The writer must not emit a ``tools`` column, so every run converges
    to the v3.1 schema.
    """
    staging_dir = tmp_path / "stage"
    _stage_episode(
        staging_dir,
        0,
        plan=[
            {"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
        ],
    )
    records = list(iter_episodes(fixture_dataset_root))
    LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
    table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
    assert "tools" not in table.column_names
|
||||
|
||||
|
||||
def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Annotated parquet columns must be declared in ``meta/info.json``.
|
||||
|
||||
``LeRobotDataset`` loads non-streaming datasets by casting parquet
|
||||
against metadata-derived HF features. If the annotation writer adds
|
||||
language columns but metadata stays stale, that cast fails with a column
|
||||
mismatch.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import load_info, load_nested_dataset
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT, language_feature_info
|
||||
|
||||
info_path = fixture_dataset_root / "meta" / "info.json"
|
||||
info = json.loads(info_path.read_text())
|
||||
info["features"] = {
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
info_path.write_text(json.dumps(info, indent=2))
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
plan=[
|
||||
{"role": "assistant", "content": "do X", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
Executor._ensure_annotation_metadata_in_info(fixture_dataset_root)
|
||||
|
||||
synced = load_info(fixture_dataset_root)
|
||||
for key, feature in language_feature_info().items():
|
||||
assert synced["features"][key] == feature
|
||||
|
||||
hf_features = get_hf_features_from_features(synced["features"])
|
||||
dataset = load_nested_dataset(fixture_dataset_root / "data", features=hf_features)
|
||||
|
||||
assert LANGUAGE_PERSISTENT in dataset.column_names
|
||||
assert LANGUAGE_EVENTS in dataset.column_names
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
    atom = speech_atom(2.5, "I'm cleaning up!")
    assert atom["role"] == "assistant"
    assert atom["style"] is None
    assert atom["content"] is None
    assert atom["timestamp"] == 2.5
    assert isinstance(atom["tool_calls"], list)
    call = atom["tool_calls"][0]
    assert call["type"] == "function"
    assert call["function"]["name"] == "say"
    assert call["function"]["arguments"]["text"] == "I'm cleaning up!"
|
||||
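The assertions in `test_speech_atom_shape_matches_plan_spec` fully determine the fields they inspect, so the row shape can be written out directly. The sketch below builds a dictionary that satisfies them. `speech_atom_sketch` is a hypothetical name, and any field the real `speech_atom` sets beyond those tested is not represented here.

```python
from typing import Any


def speech_atom_sketch(timestamp: float, text: str) -> dict[str, Any]:
    """Hypothetical stand-in: a speech row carries no content or style, only a ``say`` call."""
    return {
        "role": "assistant",
        "content": None,  # the spoken text lives in the tool call, not in content
        "style": None,  # style None is what marks the row as a speech atom
        "timestamp": timestamp,
        "tool_calls": [
            {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
        ],
    }


atom = speech_atom_sketch(2.5, "I'm cleaning up!")
```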
@@ -289,52 +289,6 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory):
|
||||
"""With concatenation disabled, each source file is kept as its own destination file."""
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "no_stitch_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_no_stitch_0",
|
||||
total_episodes=3,
|
||||
total_frames=60,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "no_stitch_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_no_stitch_1",
|
||||
total_episodes=4,
|
||||
total_frames=80,
|
||||
)
|
||||
|
||||
aggr_root = tmp_path / "no_stitch_aggr"
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_no_stitch_aggr",
|
||||
aggr_root=aggr_root,
|
||||
concatenate_videos=False,
|
||||
concatenate_data=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(aggr_root)
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_no_stitch_aggr", root=aggr_root)
|
||||
|
||||
assert_episode_and_frame_counts(
|
||||
aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames
|
||||
)
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
assert_video_timestamps_within_bounds(aggr_ds)
|
||||
|
||||
# Two single-file sources stay as two separate files instead of being packed together.
|
||||
assert len(list((aggr_root / "data").rglob("*.parquet"))) == 2
|
||||
assert aggr_ds.meta.video_keys, "Test fixture should produce at least one video feature"
|
||||
for key in aggr_ds.meta.video_keys:
|
||||
assert len(list((aggr_root / "videos" / key).rglob("*.mp4"))) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
|
||||
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
|
||||
tmp_path, lerobot_dataset_factory, caplog, mutation
|
||||
|
||||
@@ -83,29 +83,6 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_uint8_images_preserves_std():
|
||||
data = np.array(
|
||||
[
|
||||
[
|
||||
[[0, 64], [128, 255]],
|
||||
[[255, 128], [64, 0]],
|
||||
[[32, 96], [160, 224]],
|
||||
],
|
||||
[
|
||||
[[16, 80], [144, 240]],
|
||||
[[240, 144], [80, 16]],
|
||||
[[48, 112], [176, 208]],
|
||||
],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
|
||||
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
|
||||
np.testing.assert_allclose(stats["std"], expected_std)
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
|
||||
@@ -114,19 +114,6 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_is_reproducible_across_instances():
    # The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
    # produce the same permutation without any generator synchronization.
    sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
    sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
    epoch_0 = list(sampler_a)
    assert list(sampler_b) == epoch_0
    # Desyncing the global RNG must not affect the permutation.
    sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
    torch.randperm(1000)  # consume global RNG, as rank-asymmetric code (e.g. eval) would
    assert list(sampler_c) == epoch_0
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
@@ -150,87 +137,3 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- seeded (seed, epoch) shuffling, resume, and state ---
|
||||
|
||||
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||
|
||||
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||
for seed in (0, 1, 1234):
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
assert sorted(sampler) == list(range(num_frames))
|
||||
|
||||
|
||||
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||
assert epoch_1 != epoch_0
|
||||
assert sorted(epoch_1) == sorted(epoch_0)
|
||||
sampler_a.set_epoch(0)
|
||||
assert list(sampler_a) == epoch_0
|
||||
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_mid_epoch():
    reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
    epoch_0 = list(reference)
    epoch_1 = list(reference)
    for start in (0, 1, 4, len(epoch_0)):
        resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
        resumed.load_state_dict({"epoch": 0, "start_index": start})
        assert list(resumed) == epoch_0[start:]
        # the resumed sampler continues into the same epoch 1 as the uninterrupted one
        assert list(resumed) == epoch_1
|
||||
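The sampler tests rest on one idea: the epoch's order is a pure function of `(seed, epoch)`, so resuming only needs the epoch number and an offset. The sketch below illustrates that idea with the standard library's `random.Random` rather than the seeded `torch.randperm` the test comments mention, so it is an analogue and not the `EpisodeAwareSampler` implementation. `epoch_permutation` and `resume_order` are hypothetical names.

```python
import random


def epoch_permutation(num_frames: int, seed: int, epoch: int) -> list[int]:
    """Hypothetical helper: a permutation that depends only on (seed, epoch)."""
    # A private generator keeps the result independent of the global RNG state.
    rng = random.Random(f"{seed}:{epoch}")
    order = list(range(num_frames))
    rng.shuffle(order)
    return order


def resume_order(num_frames: int, seed: int, state: dict) -> list[int]:
    """Regenerate the epoch's permutation and slice from the saved offset."""
    return epoch_permutation(num_frames, seed, state["epoch"])[state["start_index"]:]


epoch_0 = epoch_permutation(100, seed=42, epoch=0)
random.random()  # consuming the global RNG must not change the permutation
same_again = epoch_permutation(100, seed=42, epoch=0)
tail = resume_order(100, 42, {"epoch": 0, "start_index": 4})
```

Because nothing but the two integers is stored, every rank can rebuild the same order without synchronizing generator state, and a checkpoint only has to record `epoch` and `start_index`.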
|
||||
|
||||
def test_deterministic_sampler_construction_stores_only_boundaries():
|
||||
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
|
||||
# instantiates from just its boundaries without materializing a per-frame index list.
|
||||
num_frames = 1_000_000
|
||||
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_is_exact_at_scale():
|
||||
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
|
||||
# permutation and slicing from the saved offset reproduces the remaining order exactly.
|
||||
num_frames = 100_000
|
||||
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
epoch_0 = list(reference)
|
||||
assert sorted(epoch_0) == list(range(num_frames))
|
||||
start = num_frames - 5
|
||||
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
|
||||
|
||||
def test_compute_sampler_state():
    # 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
    assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
        "epoch": 0,
        "start_index": 0,
    }
    # step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
    assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
        "epoch": 1,
        "start_index": 40,
    }
    # uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
    assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
        "epoch": 2,
        "start_index": 40,
    }
    # uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
    assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
        "epoch": 1,
        "start_index": 100,
    }
|
||||
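The four cases in `test_compute_sampler_state` follow one piece of arithmetic: count the underlying batches per epoch, divide them among ranks (rounding up), then split the global step into an epoch and an offset. The sketch below reproduces the asserted values under that reading. `sampler_state_sketch` is a hypothetical name, and behavior outside the tested cases (for example any clamping of `start_index`) is not covered.

```python
import math


def sampler_state_sketch(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
    """Hypothetical reconstruction of the arithmetic implied by the test comments."""
    # Underlying batches per epoch; the last one may be short.
    batches_per_epoch = math.ceil(num_frames / batch_size)
    # Each rank takes an equal share, rounded up when the split is uneven.
    steps_per_epoch = math.ceil(batches_per_epoch / num_processes)
    epoch, steps_into_epoch = divmod(step, steps_per_epoch)
    # Every per-rank step consumes batch_size samples on each of the ranks.
    start_index = steps_into_epoch * batch_size * num_processes
    return {"epoch": epoch, "start_index": start_index}


state = sampler_state_sketch(step=11, num_frames=105, batch_size=10, num_processes=2)
```

For the last test case this gives 11 underlying batches, 6 steps per rank, and `divmod(11, 6) == (1, 5)`, hence epoch 1 and a start index of `5 * 10 * 2 = 100`.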
|
||||
@@ -504,19 +504,6 @@ class TestReencodeVideo:
|
||||
assert info["video.g"] == 6
|
||||
assert info["video.crf"] == 23
|
||||
|
||||
@require_h264
|
||||
def test_reencode_video_trim_window(self, tmp_path):
|
||||
src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4"
|
||||
out = tmp_path / "trim_window.mp4"
|
||||
cfg = VideoEncoderConfig(vcodec="h264")
|
||||
reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
frames = list(container.decode(video=0))
|
||||
# Only the frames at 0.067 and 0.1 s fall inside [0.05, 0.12).
|
||||
assert len(frames) == 2
|
||||
assert frames[0].time == pytest.approx(0.0, abs=1e-3)
|
||||
|
||||
|
||||
class TestConcatenateVideoFiles:
|
||||
def test_two_clips_frame_count(self, tmp_path):
|
||||
|
||||
Vendored
-61
@@ -552,64 +552,3 @@ def lerobot_dataset_factory(
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||
|
||||
|
||||
def build_annotation_dataset(
|
||||
root: Path,
|
||||
episode_specs: list[tuple[int, int, str]],
|
||||
*,
|
||||
fps: int = 10,
|
||||
) -> Path:
|
||||
"""Build a minimal LeRobot-shaped dataset on disk for annotation tests.
|
||||
|
||||
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
|
||||
Each episode is written to its own
|
||||
``data/chunk-000/file-{ep:03d}.parquet`` so the writer's per-shard
|
||||
rewrite path is exercised. The dataset carries the minimum
|
||||
``meta/tasks.parquet`` + ``meta/info.json`` the reader / executor need;
|
||||
it has no videos, so the modules fall back to text-only prompts.
|
||||
|
||||
Shared by the annotation-pipeline pytest fixtures (``tests/annotations/
|
||||
conftest.py``) and the opt-in E2E smoke run so the fixture shape lives
|
||||
in exactly one place.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tasks: dict[int, str] = {}
|
||||
for episode_index, num_frames, task_text in episode_specs:
|
||||
if task_text not in tasks.values():
|
||||
tasks[len(tasks)] = task_text
|
||||
task_index = next(k for k, v in tasks.items() if v == task_text)
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": list(range(num_frames)),
|
||||
"timestamp": [round(i / fps, 6) for i in range(num_frames)],
|
||||
"task_index": [task_index] * num_frames,
|
||||
"subtask_index": [0] * num_frames, # legacy column the writer must drop
|
||||
}
|
||||
)
|
||||
frame.to_parquet(data_dir / f"file-{episode_index:03d}.parquet", index=False)
|
||||
|
||||
# Canonical tasks frame: indexed by task string with a ``task_index``
|
||||
# column, matching what ``lerobot.datasets.io_utils.load_tasks`` expects.
|
||||
tasks_df = pd.DataFrame(
|
||||
{"task_index": list(tasks.keys())},
|
||||
index=pd.Index(list(tasks.values()), name="task"),
|
||||
)
|
||||
write_tasks(tasks_df, root)
|
||||
|
||||
write_json(
|
||||
{
|
||||
"codebase_version": "v3.1",
|
||||
"fps": fps,
|
||||
"features": {},
|
||||
"total_episodes": len(episode_specs),
|
||||
},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
@@ -0,0 +1,518 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
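Note on the target-distribution tests above: the `dirac_delta` tests require that the distribution sums to 1, has at most two non-zero adjacent bins, reproduces the target as its expected value, and clamps out-of-support targets to the boundary bins. One common construction with these properties is a "two-hot" projection. The sketch below is an assumption about how such a target could be built, not the repository's `dirac_delta_target`; the support `[-1, 0]` is inferred from the range checks in the tests, and the function names are hypothetical.

```python
import math


def bin_centers(v_min=-1.0, v_max=0.0, num_bins=201):
    """Evenly spaced bin centers covering [v_min, v_max] (assumed support)."""
    step = (v_max - v_min) / (num_bins - 1)
    return [v_min + i * step for i in range(num_bins)]


def two_hot_target(value, v_min=-1.0, v_max=0.0, num_bins=201):
    """Spread a scalar target over its two neighbouring bins.

    The weights are chosen so that the expected value under the returned
    distribution equals the (clamped) target.
    """
    clamped = min(max(value, v_min), v_max)
    # Fractional bin position of the target, in [0, num_bins - 1].
    position = (clamped - v_min) / (v_max - v_min) * (num_bins - 1)
    lower = min(int(math.floor(position)), num_bins - 1)
    upper = min(lower + 1, num_bins - 1)
    # When the target sits on the last bin there is no upper neighbour.
    upper_weight = position - lower if upper != lower else 0.0
    dist = [0.0] * num_bins
    dist[lower] += 1.0 - upper_weight
    dist[upper] += upper_weight
    return dist


centers = bin_centers()
dist = two_hot_target(-0.7523)
expected_value = sum(p * c for p, c in zip(dist, centers))
```

An HL-Gauss target differs in that it spreads mass over many bins by integrating a Gaussian across bin edges, which is why the terminal-state test above expects more than two non-zero bins for non-terminal samples.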
@@ -66,20 +66,6 @@ class TestOperationTypeParsing:
|
||||
with pytest.raises(ValueError, match="--new_repo_id is required for merge"):
|
||||
_validate_config(cfg)
|
||||
|
||||
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
|
||||
def test_merge_concatenate_flag_defaults_true(self, flag):
|
||||
cfg = parse_cfg(["--new_repo_id", "test/merged", "--operation.type", "merge"])
|
||||
assert isinstance(cfg.operation, MergeConfig)
|
||||
assert getattr(cfg.operation, flag) is True
|
||||
|
||||
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
|
||||
def test_merge_concatenate_flag_can_be_disabled(self, flag):
|
||||
cfg = parse_cfg(
|
||||
["--new_repo_id", "test/merged", "--operation.type", "merge", f"--operation.{flag}", "false"]
|
||||
)
|
||||
assert isinstance(cfg.operation, MergeConfig)
|
||||
assert getattr(cfg.operation, flag) is False
|
||||
|
||||
def test_non_merge_requires_repo_id(self):
|
||||
cfg = parse_cfg(["--operation.type", "delete_episodes"])
|
||||
with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"):
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# ``lerobot.scripts.lerobot_annotate`` (and the ``_push_to_hub`` path it
|
||||
# exercises) imports ``lerobot.datasets``, which only ships under the
|
||||
# ``dataset`` extra. Skip in tiers without it instead of erroring.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
|
||||
def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch):
|
||||
from lerobot.scripts.lerobot_annotate import _push_to_hub
|
||||
|
||||
root = tmp_path / "dataset"
|
||||
(root / "meta").mkdir(parents=True)
|
||||
(root / "meta" / "info.json").write_text(json.dumps({"codebase_version": "v3.0"}))
|
||||
|
||||
calls = {}
|
||||
|
||||
class FakeHfApi:
|
||||
def create_repo(self, **kwargs):
|
||||
calls["create_repo"] = kwargs
|
||||
|
||||
def upload_folder(self, **kwargs):
|
||||
calls["upload_folder"] = kwargs
|
||||
return SimpleNamespace(oid="abc123")
|
||||
|
||||
def delete_tag(self, repo_id, **kwargs):
|
||||
import requests
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
calls["delete_tag"] = {"repo_id": repo_id, **kwargs}
|
||||
# Simulate the common case: no stale tag to delete.
|
||||
raise RevisionNotFoundError("no such tag", response=requests.Response())
|
||||
|
||||
def create_tag(self, **kwargs):
|
||||
calls["create_tag"] = kwargs
|
||||
|
||||
monkeypatch.setattr("huggingface_hub.HfApi", FakeHfApi)
|
||||
|
||||
cfg = SimpleNamespace(
|
||||
repo_id="source/dataset",
|
||||
new_repo_id="annotated/dataset",
|
||||
push_private=True,
|
||||
push_commit_message=None,
|
||||
)
|
||||
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
assert calls["create_repo"] == {
|
||||
"repo_id": "annotated/dataset",
|
||||
"repo_type": "dataset",
|
||||
"private": True,
|
||||
"exist_ok": True,
|
||||
}
|
||||
assert calls["upload_folder"]["repo_id"] == "annotated/dataset"
|
||||
# A stale tag (e.g. from a previous annotation run) is deleted first so
|
||||
# the new tag always points at the upload we just made.
|
||||
assert calls["delete_tag"] == {
|
||||
"repo_id": "annotated/dataset",
|
||||
"tag": "v3.0",
|
||||
"repo_type": "dataset",
|
||||
}
|
||||
assert calls["create_tag"] == {
|
||||
"repo_id": "annotated/dataset",
|
||||
"tag": "v3.0",
|
||||
"repo_type": "dataset",
|
||||
"revision": "abc123",
|
||||
}
|
||||
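Note on the test above: it checks a delete-then-create sequence, where a possibly stale tag is removed first, a "tag not found" error is tolerated, and the tag is then recreated at the revision returned by the upload. A minimal sketch of that pattern against an in-memory stand-in follows. `InMemoryTags`, `MissingTagError`, and `retag` are hypothetical names for illustration; they are not the `huggingface_hub` API or the repository's `_push_to_hub`.

```python
class MissingTagError(Exception):
    """Stand-in for a 'revision not found' error raised on a missing tag."""


class InMemoryTags:
    """Tiny fake tag store: tag name -> revision id."""

    def __init__(self):
        self.tags = {}

    def delete_tag(self, tag):
        if tag not in self.tags:
            raise MissingTagError(tag)
        del self.tags[tag]

    def create_tag(self, tag, revision):
        if tag in self.tags:
            raise ValueError(f"tag {tag!r} already exists")
        self.tags[tag] = revision


def retag(api, tag, revision):
    """Point `tag` at `revision`, whether or not the tag already exists."""
    try:
        api.delete_tag(tag)
    except MissingTagError:
        # Common case: nothing to delete on the first upload.
        pass
    api.create_tag(tag, revision)


store = InMemoryTags()
retag(store, "v3.0", "abc123")  # first run: no stale tag
retag(store, "v3.0", "def456")  # second run: stale tag replaced
```

Deleting first keeps the operation idempotent: rerunning it never fails on an existing tag, and the tag always ends up at the most recent upload.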
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@@ -26,16 +25,8 @@ def mock_metrics():
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int, reduce_fn=None):
|
||||
def __init__(self, num_processes: int):
|
||||
self.num_processes = num_processes
|
||||
self.device = torch.device("cpu")
|
||||
self._reduce_fn = reduce_fn
|
||||
|
||||
def reduce(self, tensor, reduction="mean"):
|
||||
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
|
||||
if self._reduce_fn is not None:
|
||||
return self._reduce_fn(tensor, reduction)
|
||||
return tensor
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
@@ -166,70 +157,3 @@ def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker.reset_averages()
|
||||
assert tracker.loss.avg == 0.0
|
||||
assert tracker.accuracy.avg == 0.0
|
||||
|
||||
|
||||
def test_average_meter_invalid_reduction():
|
||||
with pytest.raises(ValueError):
|
||||
AverageMeter("loss", reduction="median")
|
||||
|
||||
|
||||
def test_average_meter_reduction_stored():
|
||||
meter = AverageMeter("updt_s", reduction="max")
|
||||
assert meter.reduction == "max"
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op without accelerator
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_single_process():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=1),
|
||||
)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op when world size is 1
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
|
||||
captured = {}
|
||||
|
||||
def fake_reduce(tensor, reduction):
|
||||
captured["reduction"] = reduction
|
||||
captured["values"] = tensor.clone()
|
||||
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
|
||||
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
metrics = {
|
||||
"loss": AverageMeter("loss"), # reduction="none" -> not touched
|
||||
"update_s": AverageMeter("update_s", reduction="max"),
|
||||
}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
|
||||
)
|
||||
tracker.loss = 1.0
|
||||
tracker.update_s = 0.4
|
||||
tracker.reduce_across_ranks()
|
||||
|
||||
assert captured["reduction"] == "max"
|
||||
assert torch.allclose(captured["values"], torch.tensor([0.4]))
|
||||
assert tracker.update_s.avg == pytest.approx(0.9)
|
||||
# Metrics without a reduction stay untouched.
|
||||
assert tracker.loss.avg == 1.0
|
||||
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
|
||||
# accumulate against the cluster view rather than the stale per-rank sum.
|
||||
meter = tracker.update_s
|
||||
assert meter.sum / meter.count == pytest.approx(meter.avg)
|
||||
|
||||
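Note on the `reduce_across_ranks` test in this hunk: its last assertion checks that `avg == sum / count` still holds after the average has been overwritten with a cross-rank value, so that later updates accumulate against the reduced value. A minimal sketch of one way to preserve that invariant is shown below. The `Meter` class and its `overwrite_avg` method are hypothetical and are not the `AverageMeter` from `lerobot.utils.logging_utils`.

```python
class Meter:
    """Running average with an explicit sum and count (illustrative only)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count

    def overwrite_avg(self, reduced_avg):
        """Replace the local average with a value reduced across ranks."""
        if self.count == 0:
            # Treat the reduced value as a single observation.
            self.count = 1
        self.avg = reduced_avg
        # Rescale the sum so that avg == sum / count keeps holding.
        self.sum = reduced_avg * self.count


meter = Meter()
meter.update(0.4)         # this rank measured 0.4
meter.overwrite_avg(0.9)  # slowest rank reported 0.9
meter.update(0.5)         # later updates build on the reduced value
```

Without the rescaling step, the next `update` would recompute the average from the stale per-rank sum and silently discard the reduced value.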
@@ -20,8 +20,6 @@ from unittest.mock import Mock, patch
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
@@ -65,28 +63,6 @@ def test_load_training_step(tmp_path):
|
||||
assert loaded_step == step
|
||||
|
||||
|
||||
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
|
||||
assert load_training_num_processes(tmp_path) == 4
|
||||
|
||||
|
||||
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the world size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_num_processes(tmp_path) is None
|
||||
|
||||
|
||||
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
|
||||
assert load_training_batch_size(tmp_path) == 32
|
||||
|
||||
|
||||
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the batch size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_batch_size(tmp_path) is None
|
||||
|
||||
|
||||
def test_update_last_checkpoint(tmp_path):
|
||||
checkpoint = tmp_path / "0005"
|
||||
checkpoint.mkdir()
|
||||
|
||||