Compare commits

..

18 Commits

Author SHA1 Message Date
hq-fang c78023dae7 sync uv.lock with main 2026-05-27 15:49:59 +00:00
hq-fang 36d0ba5127 validate molmoact2 gripper range 2026-05-22 22:14:51 +00:00
hq-fang dca792951e fix molmoact2 pre-commit checks 2026-05-22 22:14:50 +00:00
hq-fang 0a369e104a skip molmoact2 tests without optional deps 2026-05-22 22:14:50 +00:00
hq-fang b0cdf99957 format molmoact2 files 2026-05-22 22:14:50 +00:00
hq-fang 733f9768b5 lazy import molmoact2 scipy 2026-05-22 22:14:50 +00:00
hq-fang 7fe49f9e54 load molmoact2 without remote code 2026-05-22 22:14:50 +00:00
hq-fang e1afb96474 fix molmoact2 hf image key resolution 2026-05-22 22:14:50 +00:00
hq-fang f395f36dec move molmoact2 config logic into config 2026-05-22 22:14:50 +00:00
hq-fang 738ba9272f use a single molmoact2 action queue 2026-05-22 22:14:50 +00:00
hq-fang 2a0495f8c3 add scipy dependency to molmoact2 extra 2026-05-22 22:14:50 +00:00
hq-fang c3c9c2b089 guard molmoact2 processor transformers import 2026-05-22 22:14:50 +00:00
hq-fang e13c6a6110 guard molmoact2 transformers imports 2026-05-22 22:14:50 +00:00
hq-fang 140cf2a420 remove molmoact2 processor override from factory 2026-05-22 22:14:50 +00:00
hq-fang c092194cf2 align molmoact2 feature validation with eo pattern 2026-05-22 22:14:50 +00:00
hq-fang b858ba1b6c simplify molmoact2 package imports 2026-05-22 22:14:50 +00:00
hq-fang e870af119f add apache headers to molmoact2 files 2026-05-22 22:14:50 +00:00
hq-fang 4174c3b303 add molmoact2 policy 2026-05-22 22:14:49 +00:00
173 changed files with 1983 additions and 17851 deletions
-3
View File
@@ -65,9 +65,6 @@ repos:
name: Format Markdown with Prettier
types_or: [markdown, mdx]
args: [--prose-wrap=preserve]
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
exclude: ^src/lerobot/templates/.*\.md$
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
-6
View File
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
--env.episode_length=5 \
--eval.n_episodes=1 \
--eval.batch_size=1
# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
uv run python -m tests.annotations.run_e2e_smoke
+7 -10
View File
@@ -58,7 +58,7 @@ action = model.select_action(obs)
robot.send_action(action)
```
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
@@ -101,13 +101,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -135,7 +133,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
## Citation
@@ -143,7 +140,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
-8
View File
@@ -9,8 +9,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: lelab
title: LeLab - Lerobot GUI
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
@@ -45,8 +43,6 @@
title: Language Columns and Recipes
- local: tools
title: Tools
- local: annotation_pipeline
title: Annotation Pipeline
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
@@ -65,8 +61,6 @@
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: groot
@@ -81,8 +75,6 @@
- sections:
- local: sarm
title: SARM
- local: robometer
title: ROBOMETER
- local: topreward
title: TOPReward
title: "Reward Models"
-291
View File
@@ -1,291 +0,0 @@
# Annotation Pipeline
`lerobot-annotate` watches each episode's video with a vision-language
model (VLM) and writes natural-language annotations back into your
dataset. It fills the two language columns from the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — straight into
`data/chunk-*/file-*.parquet`.
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
memory, interjections, speech, and visual Q&A that a policy can be
trained on.
## How it fits together
```text
your dataset lerobot-annotate
(LeRobot v3.1)
┌─────────────────────────────────────────────────────┐
│ read episodes │
└──────────────────────────┬──────────────────────────┘
┌────────────────────┼────────────────────┐
▼ ▼ ▼
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
└────────────────────┼─────────────────────┘
│ each module stages raw JSONL
▼ into .annotate_staging/
┌─────────────────┐
│ validator │ ◀── checks everything
└────────┬────────┘
┌─────────────────┐
│ writer │
└────────┬────────┘
data/chunk-*/file-*.parquet
(+ meta/info.json tools)
```
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
VLM. Each module stages its output to disk, a validator checks it, and a
single writer rewrites the dataset shards in place.
## What the pipeline produces
Each module emits a few kinds of annotation ("styles"), routed to one of
the two language columns:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | --------------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
| `interjection` | `language_events` | `interjections` |
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
### How subtasks are generated
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
it uses a two-step **describe → segment** flow:
1. **Describe** — the VLM narrates only what it actually sees in the
chosen camera (no guessing about the task).
2. **Segment** — that description is fed back in, and the VLM splits the
episode into consecutive atomic subtasks.
Both passes see the episode as **timestamped contact sheets** — frames
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
grids with each frame's time burned into its corner, so the VLM cites
exact boundary times directly. This is far cheaper in vision tokens than
one image per frame, so the sampling can stay dense; episodes longer than
`max_frames_per_prompt` are split into windows at the same density and
merged. Both prompts also carry a causal **event-boundary** definition (a
new event starts when an object becomes held / is released / reaches a new
location / a lid changes state / contents move) to sharpen where cuts land.
The resulting spans are then stitched into a gap-free, full-episode
cover, so **every frame has exactly one active subtask**. See
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
for the production settings (single camera, timestamped contact sheets,
auto-windowed subtask generation).
### Tools
The writer does **not** add a `tools` column to the parquet. The tool
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
After every run, the pipeline makes sure the canonical `say` schema is in
that list, keeping any tools you declared beforehand.
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
pipeline preserves whatever is already there. That makes the tool visible
to the chat template, so the model can learn to _generate_ the call. The
runtime layer that actually _executes_ a generated call (the `Tool`
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
this PR — the [Tools](./tools) doc marks those pieces as
not-yet-implemented.
## Running on Hugging Face Jobs
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
The repo ships a launcher script you copy and tweak for your dataset:
```bash
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
```
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
that:
1. installs `lerobot` (from `main`) plus the annotation extras,
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
drives it over the OpenAI-compatible API,
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
with `lerobot-annotate`,
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
back to `--repo_id` in place if you leave that unset).
To use a different dataset, model, or hub repo, edit the `CMD` block in
the script. Every flag there maps directly to a `lerobot-annotate` flag
(run `lerobot-annotate --help` for the full list).
## Key options
These are the flags you'll reach for most often. Run
`lerobot-annotate --help` for everything else; the defaults are tuned for
short manipulation episodes.
### Dataset in / out
| Flag | Default | What it does |
| ----------------- | ------- | ----------------------------------------------------------------------- |
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
| `--root` | — | Annotate a local dataset directory instead. |
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
### Which modules run
Every module is on by default and can be toggled independently (set to
`false` to skip it, e.g. to iterate on one module at a time):
| Flag | Default | Turns off |
| ------------------------- | ------- | ----------------------------------- |
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
| `--interjections.enabled` | `true` | interjections + speech atoms |
| `--vqa.enabled` | `true` | the VQA pairs |
### The VLM (`--vlm.*`)
| Flag | Default | What it does |
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
| `--vlm.temperature` | `0.2` | Sampling temperature. |
### Subtasks / plan / memory (`--plan.*`)
| Flag | Default | What it does |
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
### Interjections + VQA
| Flag | Default | What it does |
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
## Contributing new modules
The pipeline is built to grow, and **contributions are very welcome** —
a brand-new module (say, trajectory traces or affordances), a new prompt
template, a smarter grounding flow, or quality fixes to the existing
`plan` / `interjections` / `vqa` modules.
Every module lives under
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
client and the keyframe cache, writes its raw output to the staging
tree, and plugs into the executor as its own phase. Got an idea? Open an
issue or PR on [the repo](https://github.com/huggingface/lerobot).
## How recipes consume the output
The annotations are meant to be read by recipes (see
[Language Columns and Recipes](./language_and_recipes)). Typically:
- low-level / high-level / memory-update branches read
`subtask` / `plan` / `memory` from `language_persistent`.
- an interjection-response branch reads `interjection` events plus the
paired speech atom (merged into one assistant turn via `tool_calls_from`)
and the matching `plan` refresh at the same timestamp.
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
`language_events`.
## Why state and events are split
Two ideas shape the design:
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
`plan`, `memory`) apply to the whole episode and answer "what's true
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
the one frame whose timestamp matches. Timestamps are copied straight
from the source parquet — never recomputed in floating point.
2. **One VLM pass.** All three modules share a single VLM client (the
OpenAI-compatible client talking to the job's vLLM server), so you pay
for one model load per dataset, not three.
## Re-running a single module
Each module stages its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
prompt iteration cheap: re-running one module overwrites only its own
JSONL, then the writer recomposes the final parquet. Disable modules you
don't want with `--plan.enabled=false` (and likewise
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
## What the validator checks
Before the writer runs, `StagingValidator` confirms:
- every event row lands exactly on a real frame timestamp;
- no speech / interjection pairs are left orphaned;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (a warning, not an error);
- each VQA assistant `content` is valid JSON in one of the
bbox / keypoint / count / attribute / spatial shapes;
- every row goes to the column chosen by `column_for_style(style)`.
Any error aborts the writer. Pass `--skip_validation=true` to override
while debugging.
## Where each module's ideas come from
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
for atom granularity ("pick up one piece of lettuce", "place bowl to
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
for "how, not what" detail.
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
keep only the minimal relevant information — preserve outcomes, drop
specific attributes.
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable VLA Policies
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
grounding. Pi0.7 also grounds answers across abstraction levels.
When improving a module, tweak its prompt template in
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
rewriting from scratch.
## Roughly how much it costs
Per episode, the pipeline makes about `max_steps` plan calls,
`max_interjections_per_episode` interjection calls, and
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
~50 VLM calls.
Storage stays small: `language_persistent` is at most tens of KB per
episode (parquet dictionary-encodes the one entry that repeats across
frames), and `language_events` is empty on most frames — its size scales
with the number of emissions, not `num_frames × num_emissions`.
-8
View File
@@ -157,14 +157,6 @@ finally:
</hfoption>
</hfoptions>
### Working with depth
The Intel RealSense and Reachy 2 cameras can capture both color and depth in lockstep. Calling `read()` returns the **color** frame as `(H, W, 3)` `uint8`. Calling `read_depth()` returns the **depth map** as `(H, W, 1)` `uint16`, where each pixel value is the distance from the sensor expressed in **millimetres**. A pixel value of `0` typically means "no measurement available" (out-of-range, occluded, or low-confidence).
During recording, the control loop peeks the freshest buffered frames non-blockingly via `read_latest()` (color) and `read_latest_depth()` (depth), adding the depth map as a sibling feature (e.g. `front_depth` next to `front`).
For how depth streams are stored and encoded when recording a dataset, see the [Depth streams](./video_encoding_parameters#depth-streams) section of the video encoding guide.
## Use your phone's camera
<hfoptions id="use phone">
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:**
- `bi_openarm_mini` - Bimanual OpenArm Mini
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
-1
View File
@@ -647,6 +647,5 @@ The `--strategy.type` flag selects the execution mode:
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
- `episodic`: Episode-oriented policy recording with reset phases between episodes
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+1 -39
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=bi_openarm_mini \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
@@ -157,44 +157,6 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
### Episodic (`--strategy.type=episodic`)
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
```bash
lerobot-rollout \
--strategy.type=episodic \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so100_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/my_eval_data \
--dataset.num_episodes=20 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--dataset.single_task="Pick up the red cube"
```
Teleop is optional — if omitted the robot holds its position during the reset phase.
**Keyboard controls:**
| Key | Action |
| ----------- | -------------------------------- |
| `→` (right) | End the current episode early |
| `←` (left) | Discard episode and re-record it |
| `ESC` | Stop the recording session |
| Flag | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------- |
| `--dataset.num_episodes` | Number of episodes to record |
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
---
## Inference Backends
-29
View File
@@ -1,29 +0,0 @@
# LeLab - LeRobot Guide
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
> [!TIP]
> For now LeLab is compatible only with SO-ARM101
<Youtube id="VqyKUuW9V1g" />
### Installation
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
```
uv tool install git+https://github.com/huggingface/leLab.git && lelab
```
After install, run `lelab` from your terminal anytime to start the app.
### Features
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
+1 -1
View File
@@ -275,7 +275,7 @@ A converter aggregates perepisode files into larger shards and writes episode
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
# Convert an existing v2.1 dataset hosted on the Hub:
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
```
**What it does**
+1 -1
View File
@@ -238,7 +238,7 @@ your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
+1 -1
View File
@@ -91,7 +91,7 @@ lerobot-train \
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset \
```
-39
View File
@@ -1,39 +0,0 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+1 -1
View File
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
If you have existing datasets in v2.1 format, use the migration tool:
```bash
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id your_id/existing_dataset
```
-185
View File
@@ -1,185 +0,0 @@
# ROBOMETER
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
**Project**: [robometer.github.io](https://robometer.github.io/)
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
## Overview
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
- **Success head**: predicts per-frame task success probability.
- **Preference head**: predicts which of two trajectories better completes the task during training.
The paper trains ROBOMETER with a composite objective:
```text
L = L_pref + L_prog + L_succ
```
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
## What the LeRobot Integration Covers
- Standard `reward_model.type=robometer` configuration through LeRobot.
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
- Dense, frame-level progress and success predictions internally.
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the ROBOMETER dependencies:
```bash
pip install -e ".[robometer]"
```
If you use `uv` directly from a source checkout:
```bash
uv sync --extra robometer
```
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
## Model Inputs and Outputs
ROBOMETER expects:
- A trajectory video or sequence of frames.
- A natural-language task description.
In LeRobot datasets, the preprocessor reads:
| Config field | Default | Meaning |
| ------------------------- | ------------------------ | ----------------------------------------------------- |
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
## Usage
### Load the Reward Model Directly
```python
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
cfg = RobometerConfig(
pretrained_path="lerobot/Robometer-4B",
device="cuda",
reward_output="progress",
)
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
```
### Encode Frames and Compute a Reward
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
```python
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
# task: str
encoder = RobometerEncoderProcessorStep(
base_model_id=cfg.base_model_id,
use_multi_image=cfg.use_multi_image,
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
max_frames=cfg.max_frames,
)
encoded = encoder.encode_samples([(frames, task)])
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
reward = reward_model.compute_reward(batch)
```
`reward` is a tensor of shape `(batch_size,)`.
### Use the Reward Factory
You can also instantiate ROBOMETER through the reward factory:
```python
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
cfg = make_reward_model_config(
"robometer",
pretrained_path="lerobot/Robometer-4B",
device="cuda",
image_key="observation.images.top",
)
reward_model = make_reward_model(cfg)
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
## Configuration Notes
### Backbone and Vocabulary
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
```text
<|split_token|>
<|reward_token|>
<|pref_token|>
<|sim_token|>
<|prog_token|>
```
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
### Progress Prediction
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
### Success Prediction
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
## Limitations
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
## References
- [ROBOMETER project](https://robometer.github.io/)
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
- [Original ROBOMETER code](https://github.com/robometer/robometer)
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
## Citation
```bibtex
@inproceedings{liang2026robometer,
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
year={2026},
booktitle={Robotics: Science and Systems 2026},
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
+4 -45
View File
@@ -11,9 +11,8 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage (RGB and depth cameras are encoded with separate encoders)
7. **Re-encode Videos** - Re-encode an existing video dataset's RGB and/or depth streams with new encoder settings
8. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -123,15 +122,6 @@ lerobot-edit-dataset \
--operation.camera_encoder.g 2 \
--operation.camera_encoder.crf 30
# Convert a dataset that includes depth maps, customizing the depth encoder
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \
--operation.depth_encoder.depth_min 0.01 \
--operation.depth_encoder.depth_max 10.0 \
--operation.depth_encoder.use_log true
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
@@ -157,42 +147,11 @@ lerobot-edit-dataset \
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `camera_encoder`: Video encoder settings applied to RGB cameras — all sub-fields accessible via `--operation.camera_encoder.<field>`. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
- `depth_encoder`: Video encoder settings applied to depth-map cameras (e.g. from an Intel RealSense). In addition to the standard encoder fields it exposes the depth quantization knobs (`depth_min`, `depth_max`, `shift`, `use_log`), accessible via `--operation.depth_encoder.<field>`. These quantization settings are persisted to the dataset metadata so depth can be dequantized back to physical units on load. See the [Depth streams](./video_encoding_parameters#depth-streams) section for details.
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). Depth-map cameras are detected automatically and routed to the `depth_encoder`, while RGB cameras use the `camera_encoder`. All episodes, stats, and tasks are preserved.
#### Re-encode Videos
Re-encode the videos of an existing video dataset with different encoder settings, without going back to raw frames. RGB videos use the `camera_encoder` and depth videos use the `depth_encoder`. Provide only the encoder(s) you want to re-encode; the other stream type is left untouched.
```bash
# Re-encode all RGB videos with new settings (saves to lerobot/pusht_reencoded by default)
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec h264 \
--operation.camera_encoder.pix_fmt yuv420p \
--operation.camera_encoder.crf 23
# Re-encode both RGB and depth videos in a dataset with depth maps
lerobot-edit-dataset \
--repo_id lerobot/pusht_depth \
--operation.type reencode_videos \
--operation.camera_encoder.vcodec libx264 \
--operation.depth_encoder.vcodec ffv1
```
**Parameters:**
- `camera_encoder`: Encoder settings applied to every RGB video. Omit to skip re-encoding RGB videos.
- `depth_encoder`: Encoder settings applied to every depth video. Omit to skip re-encoding depth videos.
- `num_workers`: Number of parallel workers for processing.
> [!NOTE]
> When re-encoding depth videos, the existing depth quantization parameters (`depth_min`, `depth_max`, `shift`, `use_log`) and the `is_depth_map` flag are **preserved** — re-encoding only changes the codec/quality of the stored stream, not how depth is dequantized on load.
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Show the information of datasets
+2 -72
View File
@@ -65,76 +65,6 @@ All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
---
## Depth streams
Depth maps (Intel RealSense, Reachy 2) are stored as their **own video streams** alongside the RGB streams. Raw depth (`uint16` millimetres or `float32` metres) can't survive an 8-bit codec, so LeRobot **quantizes** each map to a 12-bit code (`[0, 4095]`) — logarithmically by default, to match the `1/depth` error profile of depth sensors — then packs it into a high-bit-depth pixel format (`gray12le`) and encodes it with a 12-bit codec.
```mermaid
flowchart LR
A["Raw depth (uint16 mm / float32 m)"] --> B["Clip to depth_min, depth_max"]
B --> C["Quantize to 12-bit code 04095 (log or linear)"]
C --> D["Pack into gray12le"]
D --> E["Encode video (hevc Main 12)"]
E --> F[("MP4 + metadata: depth_min/max, shift, use_log")]
F -. "load time (depth_output_unit)" .-> G["Dequantize to mm or m"]
classDef input fill:#e3f2fd,stroke:#1565c0,color:#0d47a1;
classDef encode fill:#ede7f6,stroke:#5e35b1,color:#311b92;
classDef store fill:#fff8e1,stroke:#f9a825,color:#e65100;
classDef load fill:#e8f5e9,stroke:#2e7d32,color:#1b5e20;
class A input;
class B,C,D,E encode;
class F store;
class G load;
```
Configure the depth pipeline through a parallel **`depth_encoder`** block (`DepthEncoderConfig`). It inherits every `VideoEncoderConfig` field (`vcodec`, `pix_fmt`, `crf`, …) and adds four quantizer knobs, set via `--dataset.depth_encoder.<field>`:
```bash
lerobot-record \
... \
--dataset.depth_encoder.vcodec=hevc \
--dataset.depth_encoder.depth_min=0.05 \
--dataset.depth_encoder.depth_max=5.0 \
--dataset.depth_encoder.use_log=true
```
| Parameter | Type | Default | Description |
| ----------- | ------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- |
| `vcodec` | `str` | `"hevc"` | Defaults to HEVC Main 12 (a 12-bit-capable codec). `ffv1` is a lossless alternative. |
| `pix_fmt` | `str` | `"gray12le"` | Single-channel 12-bit pixel format used to carry the quantized codes. |
| `depth_min` | `float` | `0.01` | Depth in metres mapped to quantum `0`. Values below are clipped on decode. |
| `depth_max` | `float` | `10.0` | Depth in metres mapped to quantum `4095`. Values above are clipped on decode. |
| `shift` | `float` | `3.5` | Pre-log offset (metres) used in logarithmic quantization for numerical stability near zero. Must satisfy `depth_min + shift > 0`. |
| `use_log` | `bool` | `True` | If `true`, quantize in log-space (recommended for typical depth sensors). Set to `false` for uniform/linear quantization. |
> [!TIP]
> `depth_min`, `depth_max`, and `shift` are always interpreted in **metres**, regardless of the input depth's unit. Inputs are auto-detected: integer arrays (e.g. `uint16` millimetres straight from a RealSense) are treated as millimetres, floating arrays as metres.
> Pick `depth_min` / `depth_max` to bracket the actual working range of your sensor — quanta outside that range saturate, which can crush detail at the boundaries.
Depth features are flagged with `"is_depth_map": true` in `meta/info.json`, and their quantizer settings (`video.depth_min`, `video.depth_max`, `video.shift`, `video.use_log`) are persisted — which is what lets depth be **dequantized back to physical units** on load.
### Output unit at load time
`depth_encoder` is a **record-time** concern. The unit that depth maps are dequantized to on _load_ (e.g. during training) is set separately by the read-time flag `--dataset.depth_output_unit`:
```bash
lerobot-train \
--dataset.repo_id=<my_username>/<my_dataset_name> \
--dataset.depth_output_unit=m \
--policy.type=act
```
| Parameter | Type | Default | Description |
| ------------------- | ----- | ------- | -------------------------------------------------------------------------------------------- |
| `depth_output_unit` | `str` | `"mm"` | Physical unit depth maps are dequantized to on load: `"mm"` (millimetres) or `"m"` (metres). |
> [!TIP]
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
---
## Persistence in dataset metadata
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
@@ -152,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
"video.pix_fmt": "yuv420p",
"video.fps": 30,
"video.channels": 3,
"is_depth_map": false,
"video.is_depth_map": false,
"video.g": 2,
"video.crf": 30,
"video.preset": "fast",
@@ -167,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip>
-235
View File
@@ -1,235 +0,0 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
-77
View File
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
Spawns one single-GPU ``h200`` job that:
1. installs ``lerobot`` from ``main`` plus the annotation extras,
2. boots one vllm server with Qwen3.6-27B (dense VLM),
3. runs the plan / interjections / vqa modules across the dataset
in free-form mode (each episode generates its own subtasks +
memory),
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
or back to ``--repo_id``.
Usage:
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
run. For larger datasets, scale to ``h200x4`` and raise
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
"""
import os
from huggingface_hub import get_token, run_job
token = os.environ.get("HF_TOKEN") or get_token()
if not token:
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
CMD = (
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
"pip install --no-deps "
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
"pip install --upgrade-strategy only-if-needed "
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
"openai && "
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-27B "
"--vlm.num_gpus=1 "
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
"--tensor-parallel-size 1 --max-model-len 32768 "
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
"--vlm.serve_ready_timeout_s=1800 "
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
)
job = run_job(
image="vllm/vllm-openai:latest",
command=["bash", "-c", CMD],
flavor="h200",
secrets={"HF_TOKEN": token},
timeout="2h",
)
print(f"Job URL: {job.url}")
print(f"Job ID: {job.id}")
+10 -36
View File
@@ -115,8 +115,8 @@ dataset = [
]
training = [
"lerobot[dataset]",
"wandb>=0.24.0,<0.28.0",
"lerobot[accelerate-dep]",
"accelerate>=1.10.0,<2.0.0",
"wandb>=0.24.0,<0.25.0",
]
hardware = [
"lerobot[pynput-dep]",
@@ -142,8 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
@@ -178,12 +177,7 @@ unitree_g1 = [
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
]
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
reachy2 = [
"reachy2_sdk>=1.0.15,<1.1.0",
"grpcio<=1.73.1",
"protobuf<=6.32.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
@@ -205,7 +199,7 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
@@ -218,43 +212,26 @@ groot = [
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
# which talks to any OpenAI-compatible server (``vllm serve`` /
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
# (see examples/annotations/run_hf_job.py).
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
"openai>=1.40,<2.0",
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
# uv's single unified lock would then cap ``torch`` for every extra
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
# install it locally only if you run your own ``vllm serve``.
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
@@ -304,7 +281,6 @@ all = [
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -315,7 +291,6 @@ all = [
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
@@ -338,7 +313,6 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
@@ -357,7 +331,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
-15
View File
@@ -1,15 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -1,36 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter
__all__ = [
"AnnotationPipelineConfig",
"LanguageColumnsWriter",
"StagingValidator",
"ValidationReport",
]
@@ -1,211 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class PlanConfig:
"""``plan`` module: subtasks + plan + memory + task augmentation."""
enabled: bool = True
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
n_task_rephrasings: int = 10
# Derive the task from video instead of episode_task: off / if_short / always.
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
derive_task_from_video: str = "if_short"
derive_task_min_words: int = 3
# --- Frame input: timestamped contact sheets (always on) ---------------
# The subtask describe/segment passes ALWAYS render the episode as
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
# grids with each frame's timestamp burned into its corner, so the VLM
# cites the exact source time of a boundary directly. This is far cheaper
# in vision tokens than one image per frame (≈2× faster subtask generation
# in practice), which is why the sampling is dense by default.
#
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
frames_per_second: float = 2.0
# Frame budget per VLM call (= columns × rows × sheets). When a whole
# episode sampled at ``frames_per_second`` exceeds this, the episode is
# AUTOMATICALLY split into consecutive windows of
# ``max_frames_per_prompt`` frames each (one describe→segment call per
# window, still at the full ``frames_per_second`` density), and the
# per-window spans are merged + stitched into one contiguous cover. So an
# episode of any length is always covered at the full sampling density.
max_frames_per_prompt: int = 60
contact_sheet_columns: int = 5
contact_sheet_frames_per_sheet: int = 20
contact_sheet_frame_width: int = 224
contact_sheet_quality: int = 84
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
# Narrate-only grounding pass before segmenting — best defense against subtasks
# invented from the task text (+1 VLM call/episode).
subtask_describe_first: bool = True
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
emit_plan: bool = True
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
# Symmetric counterpart of ``emit_plan``.
emit_memory: bool = True
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
@dataclass
class TaskAugAxesConfig:
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
enabled: bool = False
synonym_paraphrase: int = 3
omit_arm: int = 3
omit_orientation: int = 2
omit_grasp_method: int = 2
combined_omissions: int = 2
@dataclass
class InterjectionsConfig:
"""``interjections`` module: interjections + paired speech."""
enabled: bool = True
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
max_interjections_per_episode: int = 3
interjection_min_t: float = 2.0
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
interjection_window_seconds: float = 2.0
interjection_window_frames: int = 4
@dataclass
class VqaConfig:
"""``vqa`` module: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
K: int = 1
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
# True: ground VQA only on --vlm.camera_key (default: every camera).
restrict_to_default_camera: bool = False
@dataclass
class VlmConfig:
"""Shared Qwen-VL client configuration."""
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
# auto_serve=True); ``stub`` is for tests.
backend: str = "openai"
model_id: str = "Qwen/Qwen3.6-27B"
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
api_base: str = "http://localhost:8000/v1"
api_key: str = "EMPTY"
# Spawn a server if none answers api_base; False = fail fast on a remote.
auto_serve: bool = True
serve_port: int = 8000
# Override the auto-serve command; ``{port}`` substituted per replica.
serve_command: str | None = None
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
parallel_servers: int = 1
num_gpus: int = 0
client_concurrency: int = 16
serve_ready_timeout_s: float = 600.0
max_new_tokens: int = 512
temperature: float = 0.2
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
max_model_len: int | None = None
# Camera for keyframes; None → first ``observation.images.*`` key.
camera_key: str | None = None
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
chat_template_kwargs: dict[str, Any] | None = None
@dataclass
class ExecutorConfig:
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
# Episodes processed concurrently per phase; main knob for saturating the servers.
episode_parallelism: int = 16
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
# is on and ``new_repo_id`` unset.
repo_id: str | None = None
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
new_repo_id: str | None = None
root: Path | None = None
# Defaults to ``<root>/.annotate_staging/``.
staging_dir: Path | None = None
seed: int = 1729
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
# library default (torchcodec when available, else PyAV). Or pin
# ``"torchcodec"`` / ``"pyav"`` explicitly.
video_backend: str | None = None
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
push_to_hub: bool = False
push_private: bool = False
push_commit_message: str | None = None
def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -1,253 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""In-process executor that runs the annotation phases.
The executor runs **six phases** in dependency order:
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
Episode-level concurrency is controlled by
``ExecutorConfig.episode_parallelism``.
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .config import AnnotationPipelineConfig
from .reader import EpisodeRecord, iter_episodes
from .staging import EpisodeStaging
from .validator import StagingValidator
from .writer import LanguageColumnsWriter
logger = logging.getLogger(__name__)
@dataclass
class PhaseResult:
"""Summary of one pipeline phase across all episodes."""
name: str
episodes_processed: int
episodes_skipped: int
@dataclass
class PipelineRunSummary:
"""Aggregated result returned by :meth:`Executor.run`."""
phases: list[PhaseResult]
written_paths: list[Path]
validation_report: Any # ValidationReport, kept Any to avoid import cycle
@dataclass
class Executor:
"""Run all six phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
executor inside a Hugging Face Job. Tests construct the executor
directly with stub modules.
"""
config: AnnotationPipelineConfig
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
def run(self, root: Path) -> PipelineRunSummary:
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
n = len(records)
if n == 0:
raise ValueError(f"No episodes found under {root}/data/")
print(f"[annotate] {n} episodes total", flush=True)
staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True)
phases: list[PhaseResult] = []
# Phase 1: ``plan`` module (plan + subtasks + memory)
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
# Phase 2: ``interjections`` module (interjections + speech). It
# reads the ``plan`` module's subtask rows from the same staging
# tree to ground the interjection prompt in the correct local subtask.
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}")
print(f"[annotate] validator: {report.summary()}", flush=True)
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Keep meta/info.json aligned with the parquet schema we just wrote.
# Idempotent and additive: existing user metadata is preserved.
self._ensure_annotation_metadata_in_info(root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
@staticmethod
def _ensure_annotation_metadata_in_info(root: Path) -> None:
"""Write language features and canonical tools to ``meta/info.json``.
``LanguageColumnsWriter`` adds ``language_persistent`` and
``language_events`` to parquet shards. The metadata must advertise
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
cast against the old schema and fail on the extra parquet columns.
"""
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
info_path = root / "meta" / "info.json"
if not info_path.exists():
return
try:
info = load_info(root)
except Exception as exc: # noqa: BLE001
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
return
changed = False
merged_features = {**info.features, **language_feature_info()}
if merged_features != info.features:
info.features = merged_features
changed = True
existing = info.tools or []
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
info.tools = [*existing, SAY_TOOL_SCHEMA]
changed = True
if changed:
write_info(info, root)
print(
"[annotate] meta/info.json: "
f"language_features={list(language_feature_info())}, "
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
flush=True,
)
def _run_module_phase(
self,
name: str,
records: list[EpisodeRecord],
staging_dir: Path,
module: Any,
) -> PhaseResult:
if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
n = len(records)
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
print(
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
flush=True,
)
t0 = time.time()
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = time.time()
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
return i, record.episode_index, time.time() - ep_start
processed = 0
if parallelism == 1:
for i, record in enumerate(records, 1):
_, ep_idx, elapsed = _do((i, record))
processed += 1
print(
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
flush=True,
)
else:
with ThreadPoolExecutor(max_workers=parallelism) as pool:
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
for fut in as_completed(futures):
i, ep_idx, elapsed = fut.result()
processed += 1
print(
f"[annotate] {name} episode {processed}/{n} "
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase( # noqa: PLR0915
self, records: list[EpisodeRecord], staging_dir: Path
) -> PhaseResult:
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
The ``plan`` module owns the prompt; the ``interjections`` module
produced the timestamps. This phase therefore calls back into the
``plan`` module with the interjection timestamps so its existing
prompt path is reused.
"""
if not self.plan.enabled or not self.interjections.enabled:
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_rows = [
row for row in staging.read("interjections") if row.get("style") == "interjection"
]
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
if interjection_times:
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
processed += 1
# Episodes without any interjections are skipped (no plan refresh
# needed); count them so the summary's processed+skipped == total.
return PhaseResult(
name="plan_update",
episodes_processed=processed,
episodes_skipped=len(records) - processed,
)
@@ -1,481 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keyframe extraction for the annotation pipeline.
Modules attach decoded camera frames to their VLM prompts so the model can
ground subtask decomposition, interjection scenarios, and VQA in actual
visual content. The pipeline shares one provider across modules and one
episode at a time, with a small per-episode cache so multiple modules
querying the same timestamp pay decode cost once.
"""
from __future__ import annotations
import io
import logging
import math
import threading
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
import PIL.Image
import torch
from lerobot.configs.video import VideoEncoderConfig
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
from .reader import EpisodeRecord, snap_to_frame
logger = logging.getLogger(__name__)
class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps."""
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` feature keys this provider can decode."""
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
:func:`to_image_blocks` converts them to PIL only at the VLM-message
boundary.
Empty list if the camera is unavailable. ``camera_key=None`` falls back
to the provider's default camera so existing single-camera callers
(the ``plan`` and ``interjections`` modules) keep working unchanged.
"""
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` decoded frames covering the whole episode.
Sampling is uniform across the episode duration. Frames are
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
them into one ``{"type":"video", "video":<list>}`` block for a
Qwen-VL-compatible model that pools temporally itself. Empty list if
no camera available.
"""
@dataclass
class _NullProvider:
"""No-op provider used when the dataset has no video keys or in tests."""
@property
def camera_keys(self) -> list[str]:
return []
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
return []
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
return []
def null_provider() -> FrameProvider:
return _NullProvider()
@dataclass
class VideoFrameProvider:
"""Decodes frames from the dataset's ``observation.images.*`` streams.
By default the *first* camera key is used for the ``plan`` module
(subtask decomposition) and the ``interjections`` module (interjection
scenarios) — those prompts care about *what is happening*, not which
angle. The ``vqa`` module instead iterates over every camera in
:attr:`camera_keys` so each frame's
grounded answer (bbox/keypoint/...) is tagged with the camera it was
grounded against.
``camera_key`` overrides the default-camera choice but does not restrict
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
``video_for_episode`` to read a non-default stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
"""
root: Path
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
# Keyframe decode backend forwarded to
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
# uses the library default (torchcodec when available, else PyAV).
video_backend: str | None = None
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
# one-shot warn flag against concurrent updates from worker threads.
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
# Serializes decode_video_frames calls: torchcodec hands out one
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
# is not safe to drive from multiple threads at once.
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
def __post_init__(self) -> None:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
# Only ``video_keys`` are decodable here: the clip/decode paths read
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
# only for video-stored cameras. Image-stored cameras (also in
# ``camera_keys``) would KeyError, so restrict the list — and the
# default — to video keys.
keys = list(self._meta.video_keys)
# Last-resort fallback: if metadata didn't surface any video keys but
# the caller explicitly named a camera (``--vlm.camera_key=...``),
# trust them — the key is by definition known to exist on the dataset.
if not keys and self.camera_key:
keys = [self.camera_key]
self._camera_keys = keys
if self.camera_key is None:
self.camera_key = keys[0] if keys else None
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` keys available on this dataset."""
return list(self._camera_keys)
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
target = camera_key if camera_key is not None else self.camera_key
if not timestamps or target is None:
return []
# Snap each request to the nearest real frame timestamp: callers
# sample uniform grids whose points land mid-frame, and
# ``decode_video_frames`` rejects queries farther than
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
# repeat queries through the cache.
if record.frame_timestamps:
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
out: list[Any] = []
misses: list[float] = []
miss_indices: list[int] = []
with self._lock:
for i, ts in enumerate(timestamps):
key = (record.episode_index, target, round(float(ts), 6))
cached = self._cache.get(key)
if cached is not None:
out.append(cached)
else:
out.append(None)
misses.append(float(ts))
miss_indices.append(i)
if misses:
decoded = self._decode(record.episode_index, misses, target)
# ``_decode`` returns exactly one frame per requested timestamp,
# or an empty list if decoding failed wholesale. A partial list
# would mean a frame/timestamp misalignment, so only pair them up
# when the counts match (``strict=True`` then guards regressions).
if len(decoded) == len(miss_indices):
with self._lock:
for i, frame in zip(miss_indices, decoded, strict=True):
out[i] = frame
key = (record.episode_index, target, round(float(timestamps[i]), 6))
if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = frame
# filter out any None left over from decode failures
return [frame for frame in out if frame is not None]
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
The whole episode duration is covered; the model picks subtask
boundaries from the temporal pooling it does internally. Frames are
``torch.Tensor`` (see :meth:`frames_at`).
"""
target = camera_key if camera_key is not None else self.camera_key
if max_frames <= 0 or target is None or not record.frame_timestamps:
return []
n_frames = min(max_frames, len(record.frame_timestamps))
if n_frames == len(record.frame_timestamps):
timestamps = list(record.frame_timestamps)
else:
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
if t_last <= t0:
timestamps = [float(t0)] * n_frames
else:
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
timestamps = [float(t0 + i * step) for i in range(n_frames)]
return self.frames_at(record, timestamps, camera_key=target)
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
Returns ``None`` if the dataset has no video tracks or extraction
failed. Skips re-extract when the cached clip already exists.
Re-encodes to H.264 via
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
mp4 is decodable by every downstream video processor — stream-copy
would inherit the source codec (often AV1 in modern LeRobot
datasets), which vllm's libav build cannot decode.
"""
if self.camera_key is None:
return None
cache_dir.mkdir(parents=True, exist_ok=True)
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
if out_path.exists() and out_path.stat().st_size > 0:
return out_path
ep = self._meta.episodes[record.episode_index]
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
try:
reencode_video(
src,
out_path,
video_encoder=encoder,
overwrite=True,
start_time_s=from_timestamp,
end_time_s=to_timestamp,
)
except Exception:
logger.warning(
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
)
return None
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
one explicitly). Returns one frame per requested timestamp, or ``[]``
if decoding failed — callers treat ``[]`` as "no frames available".
"""
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
try:
# The module phases decode under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
# per-file decoder is single-threaded, so serialize decodes on a
# dedicated lock. Frame extraction is a small fraction of episode
# wall time (VLM calls dominate), so the contention is cheap.
with self._decode_lock:
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
)
return list(decoded)
except Exception as exc:
# Log loudly the first time so a silent vqa-module no-op (every
# prompt skipped because frames_at returned []) is debuggable from
# the job log instead of post-hoc parquet inspection. Subsequent
# failures stay quiet.
with self._lock:
already_warned = self._warned_decode_fail
if not already_warned:
self._warned_decode_fail = True
if not already_warned:
logger.warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
episode_index,
camera_key,
video_path,
self.video_backend,
exc,
exc_info=exc,
)
return []
def make_frame_provider(
root: Path, camera_key: str | None = None, video_backend: str | None = None
) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
try:
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
except Exception:
return null_provider()
if provider.camera_key is None:
return null_provider()
return provider
def _frame_to_pil(frame: Any) -> Any:
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
straight from :func:`decode_video_frames`); PIL is only created here, at
the VLM-message boundary, because the chat backends expect PIL images /
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
"""
if not isinstance(frame, torch.Tensor):
return frame
array = frame.detach().cpu()
if array.ndim == 3 and array.shape[0] in (1, 3):
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
if array.shape[-1] == 1:
array = array.squeeze(-1)
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of decoded frames as one Qwen-VL video block.
Returns ``[]`` when the list is empty, so the caller can splat the result
into a content array without a separate emptiness check.
"""
if not frames:
return []
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
"""Wrap a video file URL as one ``video_url`` block.
Used by the ``openai`` backend (transformers serve / vllm serve /
ktransformers serve), where the server handles frame sampling.
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
"""
if not url:
return []
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
A solid black badge with white text, so a VLM reading a contact sheet can
cite the exact source time of each tile (e.g. ``012.50s``) directly,
instead of the caller having to map tile position back to time. Mirrors
the macrodata/refiner contact-sheet convention.
"""
from PIL import ImageDraw, ImageFont
result = image.copy()
draw = ImageDraw.Draw(result)
font = ImageFont.load_default()
label = f"{timestamp:06.2f}s"
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
text_w, text_h = right - left, bottom - top
pad = max(3, round(min(image.width, image.height) * 0.018))
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
return result
def to_contact_sheet_blocks(
frames: Sequence[Any],
timestamps: Sequence[float],
*,
columns: int = 5,
frames_per_sheet: int = 20,
frame_width: int = 224,
quality: int = 84,
) -> list[dict[str, Any]]:
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
Each frame is resized to ``frame_width`` wide, stamped with its
episode-relative timestamp, and tiled row-major into grids of
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
block is returned per grid; many frames collapse into a few images, so a
long episode's temporal coverage stays dense at a fraction of the vision
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
aligned and equal length. Returns ``[]`` for empty input.
"""
from PIL import Image
if not frames:
return []
columns = max(1, columns)
frames_per_sheet = max(1, frames_per_sheet)
rows_per_sheet = math.ceil(frames_per_sheet / columns)
tiles: list[PIL.Image.Image] = []
for ts, frame in zip(timestamps, frames, strict=False):
img = _frame_to_pil(frame)
if not isinstance(img, PIL.Image.Image):
continue
img = img.convert("RGB")
if img.width != frame_width:
height = max(1, round(img.height * frame_width / img.width))
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
tiles.append(_draw_timestamp_badge(img, float(ts)))
if not tiles:
return []
blocks: list[dict[str, Any]] = []
for start in range(0, len(tiles), frames_per_sheet):
chunk = tiles[start : start + frames_per_sheet]
cell_w = max(tile.width for tile in chunk)
cell_h = max(tile.height for tile in chunk)
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
for i, tile in enumerate(chunk):
x = (i % columns) * cell_w
y = (i // columns) * cell_h
sheet.paste(tile, (x, y))
# JPEG round-trip at ``quality`` to match the refiner convention and
# shrink the wire payload; vision-token count is set by resolution, so
# the real saving is the grid packing, not the codec.
buf = io.BytesIO()
sheet.save(buf, format="JPEG", quality=quality)
buf.seek(0)
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
return blocks
@@ -1,25 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
]
@@ -1,248 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``vqa`` module: general VQA at a timed cadence.
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
consecutive frames, and every anchored frame gets its own VQA pair. Each
pair is grounded on that single anchor frame — there is no per-pair frame
window. For datasets with multiple cameras, every anchored frame produces
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
generated against that camera's frame and stamped with the matching
``camera`` field on the emitted rows. The resolver disambiguates via
``camera=...``; recipes that consume VQA do so through one sub-recipe
per camera (see ``recipes/pi05_hirobot.yaml``).
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
"""
from __future__ import annotations
import json
import logging
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import VqaConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..validator import classify_vqa_answer
from ..vlm_client import VlmClient
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
"""Return the relative frame indices to anchor VQA emissions to.
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
consecutive frames starting at the tick. Ticks fall on the nearest
available source frame timestamp.
"""
if hz <= 0 or k <= 0 or not frame_timestamps:
return []
t0 = frame_timestamps[0]
t_last = frame_timestamps[-1]
period = 1.0 / hz
indices: list[int] = []
t = t0
while t <= t_last + 1e-9:
# find the index of the nearest frame to t
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
for offset in range(k):
j = nearest_i + offset
if j >= len(frame_timestamps):
break
if not indices or indices[-1] != j:
indices.append(j)
t += period
# dedupe while preserving order
seen: set[int] = set()
deduped: list[int] = []
for i in indices:
if i in seen:
continue
seen.add(i)
deduped.append(i)
return deduped
@dataclass
class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: VqaConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
_warned_no_camera: bool = field(default=False, init=False, repr=False)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("vqa", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
cameras = self._target_cameras()
if not cameras:
# No camera available — emit nothing rather than producing
# untagged rows that would fail validation. Surface a loud one-
# time warning so this is never silently a no-op.
if not self._warned_no_camera:
logging.getLogger(__name__).warning(
"vqa module found no cameras on the frame provider — "
"every episode will emit zero VQA rows. Check that the "
"dataset declares observation.images.* features in "
"meta/info.json; passing --vlm.camera_key=<key> at the "
"CLI now also seeds the cameras list as a fallback."
)
self._warned_no_camera = True
staging.write("vqa", [])
return
# Build all messages first (one per (frame, camera)), then issue them
# as a single batched generate_json call so the client can fan them
# out concurrently.
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
for camera in cameras:
messages = self._build_messages(record, qtype, ts, camera)
# Skip cameras that decoded to zero frames at this ts: no point
# asking the VLM to ground a bbox without an image.
if not _has_image_block(messages):
continue
per_call.append((ts, camera, qtype, messages))
if not per_call:
staging.write("vqa", [])
return
results = self.vlm.generate_json([m for _, _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
qa = self._postprocess(result)
if qa is None:
continue
question, answer = qa
rows.append(
{
"role": "user",
"content": question,
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
rows.append(
{
"role": "assistant",
"content": json.dumps(answer, sort_keys=True),
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
staging.write("vqa", rows)
def _target_cameras(self) -> list[str]:
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
Defaults to every camera the provider exposes. Datasets with no
cameras (or test/null providers) yield an empty list, which makes
``run_episode`` a no-op.
When ``config.restrict_to_default_camera`` is set, VQA grounds on
only the provider's default camera (the single ``--vlm.camera_key``
stream), matching the plan / interjection modules so the whole
pipeline focuses on one view.
"""
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
if getattr(self.config, "restrict_to_default_camera", False):
default = getattr(self.frame_provider, "camera_key", None)
if default and default in all_cameras:
return [default]
# ``restrict_to_default_camera`` is set but the configured default
# isn't one the provider exposes. Returning it anyway would make
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
# fall through to every available camera instead.
if default:
logging.getLogger(__name__).warning(
"restrict_to_default_camera is set but camera_key=%r is not in the "
"provider's cameras %s; grounding VQA on all available cameras instead.",
default,
all_cameras,
)
return all_cameras
def _build_messages(
self,
record: EpisodeRecord,
question_type: str,
frame_timestamp: float,
camera_key: str,
) -> list[dict[str, Any]]:
prompt = load_prompt("vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict):
return None
question = result.get("question")
answer = result.get("answer")
if not isinstance(question, str) or not question.strip():
return None
if not isinstance(answer, dict):
return None
# The validator will enforce shape; here we just sanity-check that the
# answer matches *some* known shape so we can drop garbage early.
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
"""Return True if any user content block is a populated image block."""
for msg in messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for block in content:
if isinstance(block, dict) and block.get("type") == "image":
return True
return False
@@ -1,211 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
canonical task). No interjection row — the canonical task is already the
user utterance from ``meta/tasks.parquet``.
2. For mid-episode interruptions, emit a co-timestamped pair:
{role:user, style:interjection, content:<text>}
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
interjection timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import InterjectionsConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom
@dataclass
class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: InterjectionsConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
if record.frame_timestamps:
t0 = float(record.frame_timestamps[0])
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
# Pull the ``plan`` module's subtask spans for this episode so the
# interjection prompt can ground itself in the actual current
# subtask at each chosen timestamp. The ``plan`` module ran first.
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
rows.extend(self._mid_episode_interjections(record, subtask_spans))
staging.write("interjections", rows)
@staticmethod
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
current: str | None = None
for span in spans:
if float(span["start"]) <= t:
current = span.get("text")
else:
break
return current
def _initial_speech(self, record: EpisodeRecord) -> str | None:
prompt = load_prompt("interjections_initial_speech").format(
episode_task=record.episode_task,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("text"), str):
text = result["text"].strip()
if text:
return text
return None
def _mid_episode_interjections(
self,
record: EpisodeRecord,
subtask_spans: Sequence[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Generate interjections aligned with the actual demo trajectory.
Teleop data is frozen — the robot already executed every step in
the video. A *counterfactual* interjection like "actually skip
the wipe" contradicts what then happens in the video, which is
what qwen36moe-10/11 surfaced as low-quality interjections.
Instead, anchor every interjection at a subtask boundary and
write it as a natural user request for the *upcoming* subtask.
The robot's visible next behavior IS the interjection's effect,
so the training signal stays consistent: interjection text →
plan refresh → action stream all line up.
"""
if self.config.max_interjections_per_episode <= 0:
return []
if len(subtask_spans) < 2:
# Need at least one transition (subtask 0 → subtask 1).
return []
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
# Boundaries: the start time of every subtask except the first
# (which is just t0 and is covered by the initial-task speech atom).
boundaries: list[tuple[float, str, str]] = []
for i in range(1, len(subtask_spans)):
ts = float(subtask_spans[i]["start"])
if ts < self.config.interjection_min_t:
continue
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
next_text = (subtask_spans[i].get("text") or "").strip()
if not next_text:
continue
boundaries.append((ts, prev_text, next_text))
if not boundaries:
return []
n = min(self.config.max_interjections_per_episode, len(boundaries))
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
out: list[dict[str, Any]] = []
for t, prev_subtask, next_subtask in chosen:
t_snap = snap_to_frame(t, record.frame_timestamps)
# Window straddles the boundary so the VLM sees the end of the
# previous subtask and the start of the next one — same
# conditioning the policy will see at training time.
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
prompt = load_prompt("interjections_interjection").format(
episode_task=record.episode_task,
prev_subtask=prev_subtask or "(starting from initial state)",
next_subtask=next_subtask,
timestamp=t_snap,
window_seconds=self.config.interjection_window_seconds,
)
images = self.frame_provider.frames_at(record, window_ts)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
interjection_text = result.get("interjection")
speech_text = result.get("speech")
if not isinstance(interjection_text, str) or not interjection_text.strip():
continue
if not isinstance(speech_text, str) or not speech_text.strip():
continue
out.append(
{
"role": "user",
"content": interjection_text.strip(),
"style": "interjection",
"timestamp": t_snap,
"tool_calls": None,
}
)
out.append(speech_atom(t_snap, speech_text.strip()))
return out
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
"""Return a small set of frame timestamps centered on ``t_anchor``.
The window straddles the subtask boundary the interjection sits
on: roughly half the frames cover the end of the previous
subtask, half cover the start of the next one. The VLM therefore
sees BOTH what just finished AND what's about to start, which is
the conditioning we need to write a natural "now please do X"
request that matches the visible upcoming behavior.
"""
if not frame_timestamps:
return [t_anchor]
n = max(1, int(self.config.interjection_window_frames))
if n == 1:
return [t_anchor]
window = float(self.config.interjection_window_seconds)
step = window / max(1, n - 1)
# Center the window on the anchor so half lands before, half after.
start_offset = -window / 2.0
targets = [t_anchor + start_offset + step * i for i in range(n)]
first_ts = float(frame_timestamps[0])
last_ts = float(frame_timestamps[-1])
snapped: list[float] = []
seen: set[float] = set()
for tgt in targets:
clamped = min(last_ts, max(first_ts, tgt))
t = snap_to_frame(clamped, frame_timestamps)
if t not in seen:
seen.add(t)
snapped.append(t)
return snapped or [t_anchor]
@@ -1,780 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import PlanConfig
from ..frames import (
FrameProvider,
null_provider,
to_contact_sheet_blocks,
)
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
logger = logging.getLogger(__name__)
# Prepended to every describe / segment prompt so the VLM knows the images are
# timestamped contact-sheet grids, not a single video, and reads the burned-in
# per-tile timestamp when choosing boundaries.
def _contact_sheet_preamble(columns: int) -> str:
return (
"CONTACT SHEETS — how to read the images below:\n"
f"- Each image is a grid of sampled video frames, {columns} per row, "
"with time running left-to-right then top-to-bottom (row-major).\n"
"- Each frame has its timestamp burned into the top-left corner, e.g. "
'"012.50s". Use that printed timestamp (not the tile position) when you '
"choose start/end times; boundaries should land on or near a printed "
"timestamp.\n"
"- Frames continue across grids: an action may span the end of one sheet "
"and the start of the next, so do not place a boundary just because a new "
"image begins.\n\n"
)
# Appended to every describe (and segment) prompt. A visual, causal definition
# of where one event ends and the next begins — adapted from macrodata/refiner —
# to sharpen cut points while the existing prompt keeps owning the imperative
# phrasing.
_CAUSAL_BOUNDARY_RULES = (
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
"- Start a new event whenever the world state changes: an object becomes "
"held (the gripper closes on it), an object is released (the gripper opens "
"and it stays put), an object reaches a new location, a lid/door/drawer "
"changes open/closed state, a tool starts or stops affecting a surface, or "
"contents visibly move (e.g. poured).\n"
"- If a single action changes the same state gradually and continuously, "
"keep it as ONE event — do not split it.\n"
"- If the same action repeats on different objects or target locations, "
"treat each repetition as a separate event.\n"
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
"tiny hand adjustments."
)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
All output is persistent (lives in ``language_persistent``):
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
the ``interjections`` module completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: PlanConfig
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
# Task driving every plan-module prompt: canonical episode_task, or a
# video-derived one when it's empty/placeholder (see derive_task_*).
effective_task = self._resolve_effective_task(record)
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
# free-form n_task_rephrasings; the effective task is always emitted
# first so the rotation covers the source-of-truth phrasing.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
variants: list[str] | None = None
if self.config.task_aug_axes.enabled and effective_task:
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
elif self.config.n_task_rephrasings > 0 and effective_task:
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
if variants is not None:
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
subtask_spans = self._generate_subtasks(record, task=effective_task)
# subtask rows
for span in subtask_spans:
rows.append(
{
"role": "assistant",
"content": span["text"],
"style": "subtask",
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
"tool_calls": None,
}
)
# Plan rows at every subtask boundary (incl. t=0). The plan is a
# numbered list of still-todo subtasks, so re-emitting at each
# boundary makes it shrink as work progresses — ${plan} at frame t is
# exactly what's left to do.
if self.config.emit_plan:
for span in subtask_spans:
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
plan_text = self._generate_plan(
record, subtask_spans, refresh_t=boundary_t, task=effective_task
)
if plan_text is not None:
rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": float(boundary_t),
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start;
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
prior_memory = ""
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
for i, span in memory_boundaries:
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
if mem_text:
ts = snap_to_frame(span["start"], record.frame_timestamps)
rows.append(
{
"role": "assistant",
"content": mem_text,
"style": "memory",
"timestamp": ts,
"tool_calls": None,
}
)
prior_memory = mem_text
staging.write("plan", rows)
# ------------------------------------------------------------------
# Task derivation + rephrasings
# ------------------------------------------------------------------
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
{
"debug",
"test",
"tbd",
"todo",
"n/a",
"na",
"untitled",
"unnamed",
"default",
"placeholder",
}
)
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
"""Decide which task string drives the ``plan`` module for this episode.
Returns the user-supplied ``record.episode_task`` unless
``derive_task_from_video`` says otherwise (see config docstring).
Falls back gracefully to the canonical task if video derivation
fails.
"""
canonical = (record.episode_task or "").strip()
mode = (self.config.derive_task_from_video or "off").strip().lower()
if mode == "always":
derived = self._derive_task_from_video(record)
return derived or canonical
if mode == "if_short" and self._task_seems_bad(canonical):
derived = self._derive_task_from_video(record)
if derived:
return derived
return canonical
def _task_seems_bad(self, task: str) -> bool:
if not task:
return True
if len(task.split()) < int(self.config.derive_task_min_words):
return True
return task.lower() in self._PLACEHOLDER_TASKS
@staticmethod
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
seen: set[str] = set()
rows: list[dict[str, Any]] = []
for phrasing in phrasings:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
)
return rows
# ------------------------------------------------------------------
# VLM call helpers — every plan-module prompt follows the same shape:
# build messages → single VLM call → pull a named field.
# ------------------------------------------------------------------
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
"""Run a single VLM call and return ``result[field]`` or ``None``.
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
dance every prompt-call site needs.
"""
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict):
return result.get(field)
return None
@staticmethod
def _text_message(text: str) -> list[dict[str, Any]]:
"""One-shot text-only user message wrapped for ``generate_json``."""
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
def _video_message(
self,
record: EpisodeRecord,
prompt: str,
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
The prompt is always prefixed with a short explanation of how to read
the timestamped grids, so the model treats them as one ordered
sequence of frames rather than unrelated images.
"""
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
"""Ask the VLM "what is this video about" with no task hint at all."""
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
return text.strip() if isinstance(text, str) and text.strip() else None
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
"""Generate ``n`` text-only paraphrases of ``base_task``."""
if n <= 0 or not base_task:
return []
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
if not isinstance(raw, list):
return []
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
return [s for s in out if s][:n]
# ------------------------------------------------------------------
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
# ------------------------------------------------------------------
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
"""One VLM call → variants along the 5-axis taxonomy.
Variants from all axes are flattened into a single list (the
downstream pipeline doesn't need to know about the per-axis
bucketing — every variant becomes a ``task_aug`` row). Order
is preserved for reproducibility: synonym_paraphrase first,
then omit_arm, then omit_orientation, then omit_grasp_method,
then combined_omissions.
"""
if not base_task:
return []
prompt = load_prompt("plan_task_aug_axes").format(
base_task=base_task,
n_synonym=axes_cfg.synonym_paraphrase,
n_omit_arm=axes_cfg.omit_arm,
n_omit_orientation=axes_cfg.omit_orientation,
n_omit_grasp_method=axes_cfg.omit_grasp_method,
n_combined=axes_cfg.combined_omissions,
)
result = self.vlm.generate_json([self._text_message(prompt)])[0]
if not isinstance(result, dict):
return []
ordered_axes = (
"synonym_paraphrase",
"omit_arm",
"omit_orientation",
"omit_grasp_method",
"combined_omissions",
)
flat: list[str] = []
seen: set[str] = set()
for axis in ordered_axes:
entries = result.get(axis)
if not isinstance(entries, list):
continue
for item in entries:
if not isinstance(item, str):
continue
key = item.strip().strip('"').strip("'")
if not key or key in seen:
continue
seen.add(key)
flat.append(key)
return flat
def _episode_video_block(
self, record: EpisodeRecord, window: tuple[float, float] | None = None
) -> list[dict[str, Any]]:
"""Timestamped contact sheets for the describe / segmentation prompts.
Always renders the (optionally windowed) episode as contact sheets:
frames sampled at ``frames_per_second`` and packed into timestamped
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
episodes that exceed it are windowed upstream in
:meth:`_generate_subtasks` so each call stays within budget while the
full episode keeps its sampling density.
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
(``ts - w0``) to match the window-relative time frame the
segmentation prompt works in (spans are offset back to absolute time
afterwards).
"""
if not record.frame_timestamps:
return []
if window is not None:
w0, w1 = float(window[0]), float(window[1])
dur = max(0.0, w1 - w0)
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
if n <= 1 or dur <= 0.0:
timestamps = [0.5 * (w0 + w1)]
else:
step = dur / (n - 1)
timestamps = [w0 + i * step for i in range(n)]
frames = self.frame_provider.frames_at(record, timestamps)
rel = [ts - w0 for ts in timestamps[: len(frames)]]
return self._contact_sheet_blocks(frames, rel)
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
timestamps = self._uniform_episode_timestamps(record, n)
frames = self.frame_provider.frames_at(record, timestamps)
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
@staticmethod
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
ts = record.frame_timestamps
if n >= len(ts):
return [float(t) for t in ts]
t0, t_last = float(ts[0]), float(ts[-1])
if t_last <= t0 or n <= 1:
return [t0] * max(1, n)
step = (t_last - t0) / (n - 1)
return [t0 + i * step for i in range(n)]
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
"""Build timestamped contact-sheet image blocks from decoded frames."""
return to_contact_sheet_blocks(
frames,
timestamps,
columns=self.config.contact_sheet_columns,
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
frame_width=self.config.contact_sheet_frame_width,
quality=self.config.contact_sheet_quality,
)
def run_plan_updates(
self,
record: EpisodeRecord,
staging: EpisodeStaging,
interjection_times: Sequence[float],
interjection_texts: Sequence[str] | None = None,
) -> None:
"""Append additional ``plan`` rows at every interjection timestamp.
Plans refresh ONLY on user interjections (event-driven). The
interjection text is forwarded into the prompt so the refreshed plan
reflects the user's correction.
"""
if not self.config.emit_plan:
return
existing = staging.read("plan")
# Pass the last frame timestamp so the final span is closed (else its
# end == start, zero duration, and a refresh inside it is missed).
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
new_rows = list(existing)
texts: list[str | None] = (
[None] * len(interjection_times)
if interjection_texts is None
else [str(t) if t else None for t in interjection_texts]
)
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
t = snap_to_frame(raw_t, record.frame_timestamps)
if t in already_planned:
continue
already_planned.add(t)
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
if plan_text is not None:
new_rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": t,
"tool_calls": None,
}
)
staging.write("plan", new_rows)
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
"""Generate subtask spans, optionally via a multi-call quality chain.
Single call (default): watch video → emit subtask JSON.
Multi-call (opt-in, higher quality, more VLM calls):
1. ``subtask_describe_first`` — a grounding pass that narrates
ONLY what is visible (no JSON commitment to subtasks yet);
its description is injected into the segmentation prompt so
the model segments its own grounded observations instead of
pattern-matching the task text.
2. segmentation — emit subtask JSON (as before).
"""
if record.row_count == 0 or not record.frame_timestamps:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
effective_task = task if task is not None else record.episode_task
# ---- Auto-windowing (keeps the full sampling density) --------
# Contact sheets are cheap, but a whole long episode sampled at
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
# When it does, split into consecutive windows of exactly that many
# frames (one describe→segment call each, still at the full sampling
# density), then merge + stitch — so an episode of any length is
# covered at full density rather than subsampled into one sparse call.
fps = max(1e-6, float(self.config.frames_per_second))
n_whole = int(round(episode_duration * fps)) + 1
if n_whole > self.config.max_frames_per_prompt:
window_s = self.config.max_frames_per_prompt / fps
return self._generate_subtasks_windowed(record, effective_task, window_s)
# ---- Pass 1 (optional): grounding description ----------------
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, effective_task)
if description:
observation_block = (
"You watched this video and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the video) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
# ---- Pass 2: segmentation ------------------------------------
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=effective_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
observation_block=observation_block,
)
)
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
cleaned = self._clean_spans(spans, record)
if not cleaned:
return []
# ---- Full-episode coverage stitch ----------------------------
# The VLM can start after t0 or leave gaps, so frames fall through
# with no active subtask. Always stitch into a contiguous
# [t0, t_last] cover.
cleaned = self._stitch_full_coverage(cleaned, record)
return cleaned
def _generate_subtasks_windowed(
self, record: EpisodeRecord, task: str, window_s: float
) -> list[dict[str, Any]]:
"""Subtask generation in fixed-length windows at constant fps.
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
seconds, runs the describe -> segment chain on each window's own
frames (sampled at ``frames_per_second``), offsets
each window's spans back to absolute episode time, then merges +
stitches into a contiguous whole-episode cover.
"""
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
all_spans: list[dict[str, Any]] = []
w0 = t0
n_windows = 0
while w0 < t_last - 1e-6:
w1 = min(w0 + window_s, t_last)
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
n_windows += 1
w0 = w1
logger.info(
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
record.episode_index,
n_windows,
window_s,
len(all_spans),
)
# Merge across windows: clamp to the absolute episode, sort, and
# frame-snap to distinct starts (handles any boundary collisions).
cleaned = self._clean_spans(all_spans, record)
if not cleaned:
return []
return self._stitch_full_coverage(cleaned, record)
def _subtasks_for_window(
self, record: EpisodeRecord, task: str, w0: float, w1: float
) -> list[dict[str, Any]]:
"""Run describe -> segment on one ``[w0, w1]`` window.
The model works in window-RELATIVE time ``[0, L]`` (it perceives
the window as a clip starting at 0); spans are offset back to
absolute ``[w0, w1]`` before returning.
"""
window = (w0, w1)
win_len = max(0.0, w1 - w0)
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, task, window=window)
if description:
observation_block = (
"You watched this video clip and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the clip) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{win_len:.3f}",
observation_block=observation_block,
)
)
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Window-relative clamp; no frame-snap dedupe yet (done on the
# merged absolute set).
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
if not cleaned:
return []
# Offset window-relative spans back to absolute episode time.
for s in cleaned:
s["start"] = w0 + float(s["start"])
s["end"] = w0 + float(s["end"])
return cleaned
def _stitch_full_coverage(
self, spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Make subtask spans tile the full episode with no gaps.
* The first subtask starts at the episode's first frame ``t0``
(any idle / approach before the first labelled action is folded
into it), so every early frame has an active subtask.
* Each subtask's ``end`` is snapped to the next subtask's
``start`` (gaps between spans are closed), and the final
subtask's ``end`` extends to the last frame ``t_last``.
Starts are otherwise left as the (already frame-snapped, distinct)
values the VLM produced — only the FIRST start is pulled
back to ``t0``, which can't collide with a later span because it
was already the earliest. Purely deterministic; runs after the
VLM passes.
"""
if not spans or not record.frame_timestamps:
return spans
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
spans = sorted(spans, key=lambda s: float(s["start"]))
spans[0]["start"] = t0
for i in range(len(spans) - 1):
spans[i]["end"] = float(spans[i + 1]["start"])
spans[-1]["end"] = t_last
for s in spans:
if float(s["end"]) < float(s["start"]):
s["end"] = float(s["start"])
return spans
@staticmethod
def _with_causal_rules(prompt: str) -> str:
"""Append the causal event-boundary rules to a describe/segment prompt."""
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
def _clean_spans(
self,
spans: Any,
record: EpisodeRecord,
bounds: tuple[float, float] | None = None,
dedupe: bool = True,
) -> list[dict[str, Any]]:
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
``bounds`` overrides the clamp range — pass the window's
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
``None`` to clamp to the whole episode ``[t0, t_last]``.
``dedupe`` runs the frame-snap distinct-start step; skip it for
window-relative spans (frame snapping is done once on the merged,
absolute-time set).
"""
if not spans:
return []
if bounds is not None:
lo, hi = float(bounds[0]), float(bounds[1])
else:
lo = record.frame_timestamps[0]
hi = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
start = float(span["start"])
end = float(span["end"])
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(lo, min(start, hi))
end = max(lo, min(end, hi))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
if dedupe:
return self._dedupe_starts_to_distinct_frames(cleaned, record)
return cleaned
def _describe_episode(
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
) -> str:
"""Grounding pass: free-form chronological description of the (windowed) video."""
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
return text.strip() if isinstance(text, str) and text.strip() else ""
@staticmethod
def _dedupe_starts_to_distinct_frames(
spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Bump same-frame subtask starts onto distinct frames.
Two consecutive VLM spans whose ``start`` rounds to the same
source frame (after :func:`snap_to_frame`) would otherwise emit
two ``style=subtask`` rows at the identical persistent
timestamp. The training-time renderer's ``active_at(t,
style=subtask)`` resolver can't disambiguate that and raises
``Ambiguous resolver for style='subtask'``.
Walk the (sorted-by-start) spans, snap each to its frame, and
if the snapped frame is already taken push the span onto the
next unused frame so both subtasks survive on distinct
timestamps. If the episode ends before a free frame is found,
the trailing span is dropped with a warning — better than
poisoning the render.
"""
if not spans:
return spans
frames = record.frame_timestamps
if not frames:
return spans
used: set[float] = set()
out: list[dict[str, Any]] = []
for span in spans:
ts = snap_to_frame(span["start"], frames)
if ts in used:
next_ts = next((f for f in frames if f > ts and f not in used), None)
if next_ts is None:
logger.warning(
"episode %d: subtask %r snapped to occupied frame "
"%.3f and no free later frame exists — dropping",
record.episode_index,
span.get("text"),
ts,
)
continue
ts = next_ts
used.add(ts)
new_span = {**span, "start": ts}
if float(new_span.get("end", ts)) < ts:
new_span["end"] = ts
out.append(new_span)
return out
def _generate_plan(
self,
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
subtask_spans: Sequence[dict[str, Any]],
*,
refresh_t: float | None = None,
interjection: str | None = None, # noqa: ARG002
task: str | None = None, # noqa: ARG002
) -> str | None:
"""Deterministic plan = numbered list of *still-todo* subtasks.
No VLM call: a plain numbered list keeps the plan aligned with the
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
cost a round-trip per episode/refresh and could diverge).
1. <subtask 1>
2. <subtask 2>
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
interjections, and ``run_episode`` at each boundary), only subtasks
starting at or after ``refresh_t`` are included — so it always
describes what's left.
"""
if not subtask_spans:
return None
remaining = [
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
]
if not remaining:
# Past the last subtask boundary on a late refresh — nothing
# left to plan; emit None so the caller skips the row.
return None
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
def _generate_memory(
self,
record: EpisodeRecord,
prior_memory: str,
completed: str,
remaining: Sequence[str],
*,
task: str | None = None,
) -> str:
prompt = load_prompt("plan_memory").format(
episode_task=(task if task is not None else record.episode_task),
prior_memory=prior_memory or "(none)",
completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
)
memory = self._vlm_field(self._text_message(prompt), "memory")
return memory.strip() if isinstance(memory, str) else ""
@@ -1,33 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.
One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""
from __future__ import annotations
from pathlib import Path
_DIR = Path(__file__).parent
def load(name: str) -> str:
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
path = _DIR / f"{name}.txt"
return path.read_text(encoding="utf-8")
@@ -1,12 +0,0 @@
The user just asked the robot: "{episode_task}".
Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: compact, confident, friendly.
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".
Prefer very short replies: "Got it.", "On it.", "OK."
Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -1,46 +0,0 @@
You are generating training data for a Hi Robot-style hierarchical
robot policy. The robot in this demonstration has ALREADY executed
every step shown in the video — we cannot retroactively change the
action stream. To keep training data consistent with the video, the
"interjection" must align with what the robot is *about to do next* in
the demonstration, framed as a natural mid-task user request.
The episode's overall task: "{episode_task}".
The images above show roughly {window_seconds:.1f} seconds straddling a
subtask boundary in the demonstration:
- Subtask the robot just finished: "{prev_subtask}"
- Subtask the robot is about to start: "{next_subtask}"
- Time into episode: {timestamp:.2f}s
Write ONE compact interjection the user would naturally say at this
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
Keep it like a mid-task coaching cue, not a full instruction paragraph.
Also write the robot's compact verbal acknowledgement.
Hard rules:
- The interjection MUST be consistent with the next subtask. The user
cannot ask for something different from what the robot then does in
the video. If you're tempted to say "actually skip X" or "do Y
instead", DO NOT — those would contradict the demonstration.
- The interjection must reference an object, location, or action that
is plausible given the visible scene and the next subtask text.
- One short phrase or sentence each. Conversational, not robotic.
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
- Keep robot speech very short: "OK.", "On it.", "Doing that."
Style examples (vary the phrasing — don't reuse these verbatim):
- "Now go ahead and {next_subtask}."
- "Great, can you {next_subtask} next?"
- "{next_subtask}, please."
- "Before you continue, please {next_subtask}."
- "Looking good — {next_subtask} now."
- "Okay, {next_subtask}."
Output strictly valid JSON:
{{
"interjection": "<short cue from the user, asking for the next subtask>",
"speech": "<short robot acknowledgement>"
}}
@@ -1,36 +0,0 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.
Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."
Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
robot has accomplished so far — the running story it would tell itself.
Authoring rules:
- First person, past tense. Every sentence starts with "I": "I picked
up...", "I opened...", "I moved to...".
- One or two short sentences. Extend the previous memory with the
just-completed subtask; do not rewrite it from scratch.
- Keep WHAT happened (functional outcomes — where items went, how many),
drop HOW (grasp details, motions).
- Compress completed steps and drop object attributes (colors, exact
counts) once they no longer affect the remaining subtasks.
Example (MEM, Torne 2026):
Before: "I prepared the pot and got the potatoes, milk, and butter. I
moved to the drawer."
After: "I prepared the pot and got the ingredients. I opened the
drawer with the masher."
Output strictly valid JSON:
{{ "memory": "<one or two short first-person past-tense sentences>" }}
@@ -1,27 +0,0 @@
You are watching a teleoperated robot demonstration from a single
camera. The user asked the robot to: "{episode_task}"
This is an OBSERVATION pass. Watch the entire clip and describe, in
chronological order, ONLY what the robot physically does — the concrete
motions, approaches, contacts, grasps, releases, and relocations you can
actually SEE in the frames.
Hard rules:
- Describe only motion visible in the video. Do NOT use the task
instruction to guess steps that aren't shown. The instruction is the
goal; the video is ground truth.
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
the single field below. Just narrate what happens.
- Give an approximate timestamp (in seconds) for each distinct event,
e.g. "0.0-1.4s: the base drives forward toward the stove".
- Do NOT invent objects, grasps, destinations, or steps. If the robot
only does one thing (e.g. it just navigates and the clip ends), say
exactly that and nothing more.
- Be concrete and literal. "the gripper closes on the mug" — not "the
robot prepares to make coffee".
Output strictly valid JSON:
{{
"description": "<chronological, timestamped description of ONLY what is visible>"
}}
@@ -1,112 +0,0 @@
You are labeling a teleoperated robot demonstration.
The user originally asked: "{episode_task}"
You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs.
{observation_block}GROUNDING — read this first, it overrides everything below:
- Label ONLY what the robot actually does in the video. Every subtask
you emit must correspond to motion you can SEE in specific frames.
- Do NOT invent, anticipate, or pad. If the robot only does one thing
(e.g. it just navigates to a location and the clip ends), emit
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
subtasks than the ceiling is not just allowed, it is expected for
short / atomic demonstrations. One correct subtask is far better
than several invented ones.
- If the video does not clearly show the action implied by the task,
describe what you actually see — do NOT fabricate the task's steps
from the instruction text. The instruction tells you the goal; the
VIDEO is the ground truth for what happened.
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
- Each subtask = one COMPOSITE atomic skill the low-level policy can
execute end-to-end. A "skill" bundles its own approach motion with
its terminal action — do NOT split the approach off as its own
subtask. The whole-arm policy already learns to reach as part of
every manipulation primitive.
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
these verbs (extend only when none fits):
pick up <obj> — approach + grasp + lift in one subtask
put <obj> on/in <loc> — transport + release in one subtask
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
push <obj> — contact + linear shove
pull <obj> — contact + linear retract
turn <knob/dial/handle> — rotary actuation
press <button> — single-press contact
open <drawer/door/lid> — full open motion
close <drawer/door/lid> — full close motion
pour <src> into <dst> — tilt + flow
insert <obj> into <slot>— alignment + push-fit
go to <loc> — ONLY when no grasp / actuation follows
(e.g. a pure relocation between phases).
If the next subtask grasps something at
that location, drop "go to ..." and just
write "pick up ..." instead.
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
as standalone subtasks; fold them into the parent composite:
"move to X" → fold into "pick up X" (or whatever follows)
"reach for X" → fold into "pick up X"
"grasp X" → fold into "pick up X"
"lift X" → fold into "pick up X" (or "put X on Y" if it's
the transport phase of a place)
"release X" → fold into "put X on Y" (or "place X in Y")
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
detail (which hand, which grasp point) ONLY when it is needed to
disambiguate. Every subtask must begin with one of the verbs
above (no leading nouns, no "then", no "first").
- NEVER use third person. Never write "the robot", "the arm", "the
gripper moves", "it picks up" — the robot is implied. Command it,
do not describe it.
- Use the exact object nouns from the task above. If the task says
"cube", every subtask says "cube" — never switch to "block". If it
says "box", never switch to "bin"/"container". Keep vocabulary
consistent across the whole episode.
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
"turn red knob", "press start button", "go to sink".
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
must be folded into "pick up blue cube"); "the robot arm moves
towards the blue cube" (third person, too long); "carefully pick
up the cube" (adverb, article); "release the yellow block"
("block" when the task said "cube", and "release" must be folded
into a "put"/"place" subtask).
- Subtasks are non-overlapping and cover the full episode in order.
Choose the cut points yourself based on what you see in the video
(gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds. If a
candidate span would be shorter, merge it into its neighbour
rather than emitting it.
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
are preferred over many micro-steps.
- Every subtask's [start_time, end_time] must lie within
[0.0, {episode_duration}] seconds.
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
fires ONLY on the spatial situation it names; it must not change how you
label any other situation):
- STACK vs PUT: if an object is placed ON TOP OF another specific object
(not on a flat table / shelf / counter), use "stack ... on ...", not
"put". "stack blue book on green book", NOT "put blue book on table".
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
receptacle (push-fit), use "insert ... into ...", not "put".
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
on the object and the object moves WITH the hand, it is "pick up" /
"retrieve" (object leaves its location). If the gripper OPENS and the
object stays where the hand left it, it is "put" / "place" (object
arrives at a location). Decide by which way the object moves, not by
where the hand ends up.
- POUR vs PUT: only use "pour" when the source is tilted and contents
flow out; moving a full container without tilting is "put"/"place".
Output strictly valid JSON of shape:
{{
"subtasks": [
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
...
]
}}
@@ -1,67 +0,0 @@
You are generating structured augmentations of a robot task instruction
for training a language-conditioned policy. Unlike free-form rephrasing,
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
a specific element of the task while preserving its meaning.
Original task: "{base_task}"
Produce variants along five named axes. Each axis has a target count.
The whole batch should expose the policy to maximum linguistic diversity
WITHOUT changing what the robot is supposed to do.
Axes and target counts:
synonym_paraphrase ({n_synonym}):
Different wording / verbs / sentence structure. ALL information
from the original task is preserved — same object, same arm
specification if present, same orientation if present, same grasp
if present.
omit_arm ({n_omit_arm}):
Drop the left/right/both arm specification from the task. Skip
entirely (emit 0 entries) if the original task does NOT mention an
arm. Do not invent an arm specification just to omit it.
omit_orientation ({n_omit_orientation}):
Drop orientation cues (upright, sideways, facing the user,
long-edge-first, etc.). Skip entirely if no orientation cue is
present in the original task.
omit_grasp_method ({n_omit_grasp_method}):
Drop the grip / grasp method specification (pinch, wrap, hold by
the rim, etc.). Skip entirely if no grasp method is mentioned.
combined_omissions ({n_combined}):
Combine TWO of the above omissions simultaneously (e.g. drop both
arm and orientation). Skip entirely if fewer than two of (arm,
orientation, grasp_method) appear in the original task.
Hard rules:
- Each variant MUST preserve the core action, the target object, AND
the goal / destination. Do not change which object is involved, where
it goes, or the high-level action. "Navigate to the stove" may become
"go to the stove" or "head over to the stove" — it must NEVER become
"wander around the kitchen", "explore the room", or anything that
drops or generalises the stove destination. If you cannot vary the
wording without changing the goal, emit fewer variants.
- Only the FIVE listed elements (wording, arm, orientation, grasp
method, or a combination) may be varied or omitted. The verb's
meaning, the object, and the destination are fixed.
- Each variant is plain prose, no markdown, no quotes, no list numbers.
- Each variant must be DISTINCT from every other variant in the entire
output, both within and across axes. Near-duplicates are not allowed.
- If an axis cannot reach its target count because the original task
lacks the omittable element, emit fewer entries — do NOT pad the
axis with paraphrases that belong to a different axis.
- Variants should not all start with verbs — vary sentence structure
(some imperative, some polite request, some question).
Output strictly valid JSON of shape:
{{
"synonym_paraphrase": ["<v1>", "<v2>", ...],
"omit_arm": ["<v1>", "<v2>", ...],
"omit_orientation": ["<v1>", ...],
"omit_grasp_method": ["<v1>", ...],
"combined_omissions": ["<v1>", ...]
}}
@@ -1,32 +0,0 @@
You are generating training data for a Hi Robot-style policy. We need
{n} alternative phrasings of the same robot task so the policy sees
diverse user prompts during training instead of the same canonical
string repeated every frame.
Original task:
"{base_task}"
Generate exactly {n} alternative phrasings of the same task. Vary:
- formality (casual / polite / curt)
- verbosity (mostly short imperative; occasional polite request)
- word choice (synonyms, different verbs)
- sentence structure (imperative / question / suggestion)
Hard rules:
- Each phrasing MUST preserve the exact meaning of the original task.
Do not change which object is involved, the destination, or the
action. Do not add extra steps. Do not invent new objects.
- Each phrasing must be a short phrase or sentence, plain prose, no
markdown, no quotes, no list numbers.
- Phrasings must be distinct — no near-duplicates.
- Output exactly {n} entries.
Output strictly valid JSON:
{{
"rephrasings": [
"<phrasing 1>",
"<phrasing 2>",
...
]
}}
@@ -1,17 +0,0 @@
The video above shows a robot manipulation episode in full. Look at
the entire video and describe in ONE concise sentence what the robot
is doing.
Rules:
- One sentence, in natural English, like a user instruction.
- Capture the goal of the demonstration, not low-level motions.
Example: "place the yellow cube into the red bin" — not "move the
end-effector down 5cm and close the gripper".
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
- Do not invent objects or actions that aren't visible.
- Do not output anything other than the JSON object below.
Output strictly valid JSON:
{{
"task": "<single concise sentence describing what the robot does in this video>"
}}
@@ -1,32 +0,0 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
"value": "<observed value>"}}
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
"above|below|near>", "object": "<obj>"}}
Generate a question of type "{question_type}". Output strictly valid JSON:
{{
"question": "<short, frame-grounded question>",
"answer": <object whose shape matches the schema above>
}}
@@ -1,216 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:
- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
This shape lets each module operate per-episode without loading all parquet
rows into memory at once.
"""
from __future__ import annotations
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import load_tasks
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@dataclass
class EpisodeRecord:
"""Per-episode record yielded by the reader."""
episode_index: int
episode_task: str
frame_timestamps: tuple[float, ...]
frame_indices: tuple[int, ...]
data_path: Path
row_offset: int # row offset within the parquet file where this episode starts
row_count: int # number of rows for this episode
# Memoized parquet slice — populated on first ``frames_df()`` call so
# repeat queries from different modules don't re-read the whole shard.
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
def frames_df(self): # type: ignore[no-untyped-def]
"""Lazy-load the pandas slice for this episode (memoized)."""
if self._frames_df_cache is None:
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
table = pq.read_table(self.data_path)
df: pd.DataFrame = table.to_pandas()
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
drop=True
)
return self._frames_df_cache
def reconstruct_subtask_spans(
rows: Sequence[dict[str, Any]],
*,
episode_end_t: float | None = None,
) -> list[dict[str, Any]]:
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
Each span's ``end`` is the next span's ``start``. The final span's
``end`` defaults to its own ``start`` (zero-duration) — pass
``episode_end_t`` to extend it to the episode's last frame instead,
which is what downstream consumers (memory, interjection boundary
selection) expect.
Used by the ``plan`` module (plan-update pass) and the
``interjections`` module (interjection anchoring), which both need the
same span shape.
"""
sorted_rows = sorted(
(r for r in rows if r.get("style") == "subtask"),
key=lambda r: float(r["timestamp"]),
)
spans: list[dict[str, Any]] = []
for r in sorted_rows:
t = float(r["timestamp"])
if spans:
spans[-1]["end"] = t
spans.append({"text": r.get("content") or "", "start": t, "end": t})
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
spans[-1]["end"] = float(episode_end_t)
return spans
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp.
Modules use this when emitting event-style rows so the row's
timestamp matches a real parquet frame: event rows must land on an
exact frame, otherwise the per-frame event lookup the writer does
would never match them.
"""
if not frame_timestamps:
return float(t)
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
return float(nearest)
def _load_tasks_lookup(root: Path) -> dict[int, str]:
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
Returns an empty dict when the file is absent — the task description is
derived later from the video if needed. Reuses the library-level
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
frame indexed by task string with a ``task_index`` column.
"""
if not (root / DEFAULT_TASKS_PATH).exists():
return {}
tasks = load_tasks(root)
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
Episodes are yielded in ascending ``episode_index`` order. The reader does
not assume a specific chunk/file layout: it scans every ``*.parquet``
under ``data/`` and groups by ``episode_index``.
"""
tasks = _load_tasks_lookup(root)
data_dir = root / "data"
parquet_files = sorted(data_dir.rglob("*.parquet"))
only_set = set(only_episodes) if only_episodes is not None else None
for path in parquet_files:
yield from _iter_one_path(path, tasks, only_set)
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
table = pq.read_table(path)
names = table.column_names
if "episode_index" not in names:
return
episode_col = table.column("episode_index").to_pylist()
timestamp_col = (
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
)
frame_col = (
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
)
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
def _build(
ep: int,
start: int,
end: int,
task_idx: int | None,
ts_buf: list[float],
fi_buf: list[int],
) -> EpisodeRecord | None:
if only_set is not None and ep not in only_set:
return None
task = tasks.get(task_idx, "") if task_idx is not None else ""
return EpisodeRecord(
episode_index=ep,
episode_task=task,
frame_timestamps=tuple(ts_buf),
frame_indices=tuple(fi_buf),
data_path=path,
row_offset=start,
row_count=end - start,
)
cur_ep: int | None = None
start_offset = 0
ts_buf: list[float] = []
fi_buf: list[int] = []
cur_task_idx: int | None = None
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
continue
if ep != cur_ep:
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
else:
ts_buf.append(timestamp_col[i])
fi_buf.append(frame_col[i])
if cur_ep is not None:
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
@@ -1,92 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.
Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.
JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"plan",
"interjections",
"vqa",
)
@dataclass
class EpisodeStaging:
"""Filesystem layout for a single episode's staged module outputs."""
root: Path
episode_index: int
@property
def episode_dir(self) -> Path:
return self.root / f"episode_{self.episode_index:06d}"
def path_for(self, module: ModuleName) -> Path:
if module not in _MODULES:
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
return self.episode_dir / f"{module}.jsonl"
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
path = self.path_for(module)
path.parent.mkdir(parents=True, exist_ok=True)
# Atomic replace: a crash mid-write would otherwise leave a
# half-written JSONL file that ``read()`` would then fail to
# parse. Write to a sibling .tmp and rename so the target path
# only ever points at a complete file.
tmp_path = path.with_suffix(path.suffix + ".tmp")
with tmp_path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
f.write("\n")
tmp_path.replace(path)
return path
def read(self, module: ModuleName) -> list[dict[str, Any]]:
path = self.path_for(module)
if not path.exists():
return []
out: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
return {m: self.read(m) for m in _MODULES}
def has(self, module: ModuleName) -> bool:
return self.path_for(module).exists()
@@ -1,332 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after all three modules have written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
Checks (per the plan's "Intermediate staging and validation" section):
- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import (
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
column_for_style,
is_view_dependent_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class ValidationReport:
"""Outcome of one validation pass across all episodes."""
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
episodes_checked: int = 0
@property
def ok(self) -> bool:
return not self.errors
def add_error(self, message: str) -> None:
self.errors.append(message)
def add_warning(self, message: str) -> None:
self.warnings.append(message)
def summary(self) -> str:
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
def classify_vqa_answer(payload: Any) -> str | None:
"""Best-effort classification of a VQA answer payload to a question type."""
if not isinstance(payload, dict):
return None
keys = set(payload.keys())
for kind, required in VQA_ANSWER_SHAPES.items():
if required.issubset(keys):
return kind
return None
@dataclass
class StagingValidator:
"""Walks the staging tree and produces a :class:`ValidationReport`."""
timestamp_atol: float = 0.0 # exact-match by default
dataset_camera_keys: tuple[str, ...] | None = None
"""Known ``observation.images.*`` keys on the dataset. When set, the
validator additionally enforces that every view-dependent row's
``camera`` field references one of these keys. Pass ``None`` (default)
to skip that cross-check (e.g. in unit tests with no real dataset)."""
def validate(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
) -> ValidationReport:
report = ValidationReport()
for record in records:
self._validate_episode(record, staging_dir, report)
report.episodes_checked += 1
return report
def _validate_episode(
self,
record: EpisodeRecord,
staging_dir: Path,
report: ValidationReport,
) -> None:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged = staging.read_all()
all_rows: list[dict[str, Any]] = []
for module_name, rows in staged.items():
for row in rows:
row = {**row, "_module": module_name}
all_rows.append(row)
frame_ts = set(record.frame_timestamps)
events: list[dict[str, Any]] = []
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
# ``_check_column_routing`` already recorded any unknown-style error;
# don't let the same ``column_for_style`` lookup raise here uncaught.
try:
column = column_for_style(row.get("style"))
except ValueError:
continue
if column == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
for row in events:
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
self._check_speech_interjection_pairs(events, report, record.episode_index)
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
self._check_vqa_json(events, report, record.episode_index)
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
def _check_camera_field(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
dataset_camera_keys: Sequence[str] | None,
) -> None:
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
style = row.get("style")
camera = row.get("camera")
try:
validate_camera_field(style, camera)
except ValueError as exc:
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
return
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
report.add_error(
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
)
def _check_vqa_uniqueness_per_frame_camera(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
counts: dict[tuple[float, str, str], int] = {}
for row in events:
if row.get("style") != "vqa":
continue
ts = row.get("timestamp")
camera = row.get("camera")
role = row.get("role")
if ts is None or camera is None or role is None:
continue # other validators flag these
key = (float(ts), str(camera), str(role))
counts[key] = counts.get(key, 0) + 1
for (ts, camera, role), n in counts.items():
if n > 1:
report.add_error(
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
)
def _check_column_routing(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
) -> None:
style = row.get("style")
module = row.get("_module")
try:
target_col = column_for_style(style)
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
def _check_event_timestamp_alignment(
self,
row: dict[str, Any],
frame_ts: set[float],
report: ValidationReport,
episode_index: int,
) -> None:
ts = row.get("timestamp")
if ts is None:
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
return
if self.timestamp_atol == 0.0:
if float(ts) not in frame_ts:
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
)
else:
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
)
def _check_speech_interjection_pairs(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
speech_ts: dict[float, int] = {}
interjection_ts: dict[float, int] = {}
for row in events:
ts = row.get("timestamp")
if ts is None:
continue
ts_f = float(ts)
if row.get("style") is None and row.get("role") == "assistant":
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
if row.get("style") == "interjection":
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
for ts in interjection_ts:
if ts not in speech_ts:
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
def _check_plan_memory_consistency(
self,
persistent: Sequence[dict[str, Any]],
events: Sequence[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
interjection_ts = sorted(
{
float(r["timestamp"])
for r in events
if r.get("style") == "interjection" and r.get("timestamp") is not None
}
)
if persistent and not plan_ts:
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
# every interjection should have a same-timestamp plan refresh
for ts in interjection_ts:
if ts not in set(plan_ts):
report.add_error(
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
)
# memory should be emitted at subtask boundaries (subset relation)
if memory_ts and subtask_ts:
mem_set = set(memory_ts)
sub_set = set(subtask_ts)
stray = sorted(mem_set - sub_set)
if stray:
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
def _check_vqa_json(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
for row in events:
if row.get("style") != "vqa" or row.get("role") != "assistant":
continue
content = row.get("content")
if content is None:
report.add_error(
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
)
continue
try:
payload = json.loads(content)
except (TypeError, ValueError) as exc:
report.add_error(
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
)
continue
shape = classify_vqa_answer(payload)
if shape is None:
report.add_error(
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
)
@@ -1,617 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.
The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.
The client speaks one method, :meth:`VlmClient.generate_json`, which:
- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output from the server,
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
message before raising.
"""
from __future__ import annotations
import atexit
import base64
import io
import json
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
import urllib.request
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Protocol
from .config import VlmConfig
class VlmClient(Protocol):
"""Protocol every backend must implement."""
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
"""Generate one JSON-decoded response per messages list."""
@dataclass
class StubVlmClient:
"""Deterministic stub used in unit tests.
A test passes a callable that maps the *last user message text* (or, if
that is empty, the full message list) to a JSON-serializable response.
"""
responder: Callable[[Sequence[dict[str, Any]]], Any]
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
return [self.responder(list(messages)) for messages in messages_batch]
def _strip_to_json(text: str) -> Any:
text = text.strip()
# Strip <think>...</think> blocks (Qwen3 Thinking style)
while "<think>" in text and "</think>" in text:
start = text.find("<think>")
end = text.find("</think>", start) + len("</think>")
text = (text[:start] + text[end:]).strip()
# Strip ```json ... ``` fences from chat-tuned backbones
if text.startswith("```"):
first = text.find("\n")
last = text.rfind("```")
if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip()
try:
return json.loads(text)
except (ValueError, json.JSONDecodeError):
pass
# Fall back to extracting the first balanced {...} block.
obj_text = _extract_first_json_object(text)
if obj_text is None:
raise json.JSONDecodeError("No JSON object found", text, 0)
return json.loads(obj_text)
def _extract_first_json_object(text: str) -> str | None:
"""Return the first balanced ``{...}`` substring, ignoring braces in
string literals. Returns ``None`` if no balanced block is found."""
start = text.find("{")
if start < 0:
return None
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
ch = text[i]
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
# Note: ``escape`` is always False here — the ``if escape`` branch
# above already handled and reset it.
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
@dataclass
class _GenericTextClient:
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
config: VlmConfig
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
temp = temperature if temperature is not None else self.config.temperature
raw = self.generate_text(messages_batch, max_tok, temp)
out: list[Any] = []
for messages, text in zip(messages_batch, raw, strict=True):
try:
out.append(_strip_to_json(text))
continue
except (ValueError, json.JSONDecodeError):
pass
retry = list(messages) + [
{"role": "assistant", "content": text},
{
"role": "user",
"content": (
"Your previous reply was not valid JSON. "
"Reply with strictly valid JSON, no prose, no fences."
),
},
]
retry_text = self.generate_text([retry], max_tok, temp)[0]
try:
out.append(_strip_to_json(retry_text))
except (ValueError, json.JSONDecodeError):
# After retry: log preview and return None instead of crashing
# the whole pipeline. Modules treat None as "skip".
preview = retry_text.strip().replace("\n", " ")[:200]
print(
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
flush=True,
)
out.append(None)
return out
def make_vlm_client(config: VlmConfig) -> VlmClient:
"""Build the shared VLM client.
Only the ``openai`` backend is supported for now. The shipped workflow
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
optionally auto-spawning the server via ``auto_serve`` /
``serve_command``). The former in-process ``vllm`` / ``transformers``
backends were removed to keep the support surface to the HF Jobs path.
For ``stub``, construct :class:`StubVlmClient` directly with a responder
callable; it is rejected here to make accidental misuse obvious.
"""
if config.backend == "openai":
return _make_openai_client(config)
if config.backend == "stub":
raise ValueError(
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
)
if config.backend in {"vllm", "transformers"}:
raise ValueError(
f"backend={config.backend!r} (in-process local model) is not supported for now — "
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
)
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
def _make_openai_client(config: VlmConfig) -> VlmClient:
"""Backend that talks to any OpenAI-compatible server.
Compatible with ``vllm serve``, ``transformers serve``,
``ktransformers serve``, and hosted endpoints. By default the server
is expected to be already running. Set ``auto_serve=True`` to have
this client spawn one (default: ``transformers serve``), wait until
it's ready, and tear it down on process exit.
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
auto-converted to ``image_url`` data-URLs. Video blocks
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
multi-frame ``video_url`` items where supported.
"""
try:
from openai import OpenAI # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"openai package is required for backend='openai'. Install with `pip install openai`."
) from exc
api_base = config.api_base
api_key = config.api_key
auto_serve = config.auto_serve
api_bases: list[str] = [api_base]
print(
f"[lerobot-annotate] backend=openai model={config.model_id} "
f"api_base={api_base} auto_serve={auto_serve}",
flush=True,
)
if auto_serve:
if config.parallel_servers > 1:
print(
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
flush=True,
)
api_bases = _spawn_parallel_inference_servers(config)
elif _server_is_up(api_base):
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
else:
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
api_base = _spawn_inference_server(config)
api_bases = [api_base]
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
# round-robin counter for parallel mode
rr_counter = {"i": 0}
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
# rejects it with HTTP 422. Send it only when explicitly opted in via
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
extra_body: dict[str, Any] = {}
if send_mm_kwargs and mm_kwargs:
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
if config.chat_template_kwargs:
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
if extra_body:
kwargs["extra_body"] = extra_body
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
if len(batch) <= 1 or config.client_concurrency <= 1:
return [_one_call(messages, max_tok, temp) for messages in batch]
# Parallel fan-out — vllm batches these on the server side.
max_workers = min(config.client_concurrency, len(batch))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
return [f.result() for f in futures]
return _GenericTextClient(_gen, config)
def _bind_serve_port(cmd: str, port: int) -> str:
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
if present, else append ``--port`` when the command omits it (leaving an
explicit ``--port`` untouched). Shared by the single- and parallel-server
paths so a serve_command never reaches the server with a literal
``{port}``."""
if "{port}" in cmd:
return cmd.replace("{port}", str(port))
if "--port" not in cmd:
return f"{cmd} --port {port}"
return cmd
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
"""Spawn ``config.parallel_servers`` independent vllm replicas.
Each replica:
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
- listens on ``serve_port + i``
- is shut down via the same atexit hook as the single-server path
Returns the list of ``api_base`` URLs the client should round-robin
across.
"""
n = config.parallel_servers
api_bases: list[str] = []
procs: list[subprocess.Popen] = []
ready_events: list[threading.Event] = []
# Multiple readiness signals — uvicorn's own banner is suppressed at
# ``--uvicorn-log-level warning``, so we also accept vllm's own
# "Starting vLLM API server" line and the route-listing line. The
# HTTP probe below is the ultimate fallback.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
# Single lock for all server-stream threads so multibyte chars from
# different servers don't interleave and tear UTF-8 sequences.
print_lock = threading.Lock()
base_cmd = config.serve_command or (
f"vllm serve {shlex.quote(config.model_id)} "
f"--tensor-parallel-size 1 "
f"--max-model-len {config.max_model_len or 32768} "
f"--uvicorn-log-level warning"
)
num_gpus = config.num_gpus if config.num_gpus > 0 else n
for i in range(n):
port = config.serve_port + i
gpu = i % num_gpus
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = _bind_serve_port(base_cmd, port)
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
env=env,
)
procs.append(proc)
ready = threading.Event()
ready_events.append(ready)
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
# Read whole lines and emit each line atomically under the
# shared print_lock so output from N servers stays readable.
assert p.stdout is not None
for line in iter(p.stdout.readline, ""):
with print_lock:
sys.stdout.write(f"[server-{idx}] {line}")
if not line.endswith(("\n", "\r")):
sys.stdout.write("\n")
sys.stdout.flush()
if any(m in line for m in ready_markers):
ev.set()
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
while not ev.is_set() and p.poll() is None:
if _server_is_up(base):
print(f"[server-{idx}] ready (http probe)", flush=True)
ev.set()
return
time.sleep(2)
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
def _shutdown() -> None:
for i, p in enumerate(procs):
if p.poll() is None:
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
p.send_signal(signal.SIGINT)
for p in procs:
try:
p.wait(timeout=15)
except subprocess.TimeoutExpired:
p.kill()
p.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
for i, p in enumerate(procs):
if p.poll() is not None:
raise RuntimeError(
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
)
time.sleep(2)
if any(not ev.is_set() for ev in ready_events):
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
return api_bases
def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
url = api_base.rstrip("/") + "/models"
# ``api_base`` is the user-configured local-server URL we just spawned
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
# is for arbitrary user-controlled URLs with file:/ schemes which
# cannot reach this code path.
try:
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
return resp.status == 200
except Exception: # noqa: BLE001
return False
def _spawn_inference_server(config: VlmConfig) -> str:
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
accepts ``/v1/models``, and register a shutdown hook.
Streams the server's stdout/stderr to the parent terminal in
real-time on a background thread so users can see model-load
progress and errors as they happen.
Returns the full ``api_base`` URL the OpenAI client should use.
"""
cmd = config.serve_command
if not cmd:
cmd = (
f"transformers serve {shlex.quote(config.model_id)} "
f"--port {config.serve_port} --continuous-batching"
)
# Bind the single server to ``serve_port`` (what ``api_base`` below
# targets): substitute a literal ``{port}`` placeholder, else append
# ``--port``. Without this a serve_command carrying ``{port}`` would
# reach the server unsubstituted and fail to parse.
cmd = _bind_serve_port(cmd, config.serve_port)
api_base = f"http://localhost:{config.serve_port}/v1"
print(f"[server] launching: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
# Watch the server output for the uvicorn readiness banner. This is
# more reliable than polling /v1/models because transformers serve
# rescans its cache on every model-list request, which can exceed
# the urllib timeout and trigger an infinite probe loop.
ready_event = threading.Event()
# See _spawn_parallel_inference_servers for why we accept these.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
def _probe() -> None:
while not ready_event.is_set() and proc.poll() is None:
if _server_is_up(api_base):
print("[server] ready (http probe)", flush=True)
ready_event.set()
return
time.sleep(2)
threading.Thread(target=_probe, daemon=True).start()
def _stream_output() -> None:
# Read raw chunks instead of iterating lines so tqdm progress
# bars (which overwrite using \r) flush in real time.
assert proc.stdout is not None
buf = ""
prefix_started = False
while True:
ch = proc.stdout.read(1)
if ch == "":
# process exited; flush any tail
if buf:
sys.stdout.write(buf)
sys.stdout.flush()
return
if not prefix_started:
sys.stdout.write("[server] ")
prefix_started = True
sys.stdout.write(ch)
sys.stdout.flush()
buf += ch
if ch in ("\n", "\r"):
if any(marker in buf for marker in ready_markers):
ready_event.set()
buf = ""
prefix_started = False
threading.Thread(target=_stream_output, daemon=True).start()
def _shutdown() -> None:
if proc.poll() is None:
print(f"[server] stopping pid={proc.pid}", flush=True)
proc.send_signal(signal.SIGINT)
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while time.monotonic() < deadline:
if proc.poll() is not None:
raise RuntimeError(
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
f"See [server] log lines above for the cause."
)
if ready_event.wait(timeout=2):
return api_base
proc.terminate()
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
def _to_openai_messages(
messages: Sequence[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
"""Convert internal messages to OpenAI chat format.
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
(``fps`` from ``video_url`` blocks) are extracted out so the caller
can pass them via ``extra_body.mm_processor_kwargs`` rather than
inside the content blocks (which transformers serve rejects).
File-URL video blocks are inlined as base64 data URLs.
"""
out_messages: list[dict[str, Any]] = []
mm_kwargs: dict[str, Any] = {}
for message in messages:
content = message.get("content")
if not isinstance(content, list):
out_messages.append({"role": message["role"], "content": content})
continue
out_blocks: list[dict[str, Any]] = []
for block in content:
block_type = block.get("type") if isinstance(block, dict) else None
if block_type == "text":
out_blocks.append({"type": "text", "text": block.get("text", "")})
elif block_type == "image":
out_blocks.append(
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
)
elif block_type == "video":
frames = block.get("video", [])
for img in frames:
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
elif block_type == "video_url":
video_url = dict(block["video_url"])
url = video_url.get("url", "")
if url.startswith("file://"):
video_url["url"] = _file_to_data_url(url[len("file://") :])
out_blocks.append({"type": "video_url", "video_url": video_url})
fps = block.get("fps")
if fps is not None:
mm_kwargs["fps"] = fps
else:
out_blocks.append(block)
out_messages.append({"role": message["role"], "content": out_blocks})
return out_messages, mm_kwargs
def _file_to_data_url(path: str) -> str:
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
return f"data:video/mp4;base64,{b64}"
def _pil_to_data_url(image: Any) -> str:
"""Encode a PIL.Image as a base64 data URL."""
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
return f"data:image/png;base64,{b64}"
@@ -1,341 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.
For every episode the writer:
1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. writes the parquet shard back in place.
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant —
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
it directly rather than reading a redundant per-row column.
Invariants enforced here (and re-checked by the validator):
- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
(timestamps come straight from the source parquet — never recomputed);
- every row passes ``column_for_style(style)``.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
# Tool schema constants live in lerobot.datasets.language — single
# source of truth. Re-exported here so existing imports
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
# keep working.
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
# events are bucketed per-frame, but within a frame we still want determinism
return (
row.get("style") or "",
row.get("role") or "",
row.get("camera") or "",
)
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
"""Coerce a staged row into the language-column struct shape.
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
writer infers the parquet struct schema from insertion order, so
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
"""
camera = row.get("camera")
validate_camera_field(style, camera)
out: dict[str, Any] = {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
}
if with_timestamp:
out["timestamp"] = float(row["timestamp"])
out["camera"] = None if camera is None else str(camera)
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
return out
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
if style not in PERSISTENT_STYLES:
raise ValueError(
f"persistent slice contains row with non-persistent style {style!r}; "
"row would be misrouted under column_for_style()"
)
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
if "role" not in row:
# Friendly error from the writer instead of a raw KeyError below;
# the validator doesn't check ``role`` yet.
raise ValueError(f"persistent row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=True)
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
style = row.get("style")
if style is not None and style not in EVENT_ONLY_STYLES:
raise ValueError(
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
)
if column_for_style(style) != LANGUAGE_EVENTS:
raise ValueError(f"event row with style {style!r} would not route to language_events")
if "role" not in row:
raise ValueError(f"event row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=False)
def _normalize_tool_calls(value: Any) -> list[Any] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
return list(value)
def _validate_atom_invariants(row: dict[str, Any]) -> None:
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
has_content = row.get("content") is not None
has_tools = row.get("tool_calls") is not None
if not (has_content or has_tools):
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
if row.get("style") is None and not has_tools:
raise ValueError(f"style=None requires tool_calls: {row!r}")
def _validate_speech_atom(row: dict[str, Any]) -> None:
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
if row.get("style") is not None:
return # not a speech atom
if row.get("role") != "assistant":
raise ValueError(f"speech atom must have role=assistant: {row!r}")
if row.get("content") is not None:
raise ValueError(f"speech atom must have content=null: {row!r}")
tool_calls = row.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
first = tool_calls[0]
if not isinstance(first, dict):
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
if first.get("type") != "function":
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
fn = first.get("function") or {}
if fn.get("name") != "say":
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
args = fn.get("arguments") or {}
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
@dataclass
class LanguageColumnsWriter:
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
drop_existing_subtask_index: bool = True
def write_all(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> list[Path]:
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
for record in records:
episodes_by_path[record.data_path].append(record)
written: list[Path] = []
for path, eps in episodes_by_path.items():
self._rewrite_one(path, eps, staging_dir, root)
written.append(path)
return written
def _rewrite_one(
self,
path: Path,
episodes: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> None:
table = pq.read_table(path)
n_rows = table.num_rows
# Ensure we cover every episode in the file. Episodes that don't have
# staging artifacts are passed through with empty annotation lists —
# this keeps the writer idempotent and safe for partial reruns.
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
for record in episodes:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged_per_ep[record.episode_index] = staging.read_all()
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
for ep_index, ep_staged in staged_per_ep.items():
persistent_rows: list[dict[str, Any]] = []
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
for _module_name, rows in ep_staged.items():
for row in rows:
style = row.get("style")
if column_for_style(style) == LANGUAGE_PERSISTENT:
persistent_rows.append(row)
else:
event_rows.append(row)
persistent_rows.sort(key=_row_persistent_sort_key)
normalized_persistent = []
for r in persistent_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
normalized_persistent.append(_normalize_persistent_row(r))
persistent_by_ep[ep_index] = normalized_persistent
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
for r in event_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
ts = float(r["timestamp"])
buckets[ts].append(_normalize_event_row(r))
for ts in list(buckets.keys()):
buckets[ts].sort(key=_row_event_sort_key)
events_by_ep_ts[ep_index] = buckets
episode_col = (
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
)
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
if episode_col is None or ts_col is None:
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
per_row_persistent: list[list[dict[str, Any]]] = []
per_row_events: list[list[dict[str, Any]]] = []
for i in range(n_rows):
ep = episode_col[i]
ts = float(ts_col[i])
per_row_persistent.append(persistent_by_ep.get(ep, []))
buckets = events_by_ep_ts.get(ep, {})
per_row_events.append(buckets.get(ts, []))
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
self,
table: pa.Table,
persistent: list[list[dict[str, Any]]],
events: list[list[dict[str, Any]]],
*,
drop_old: bool,
) -> pa.Table:
cols = []
names = []
for name in table.column_names:
if drop_old and name == "subtask_index":
continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
continue # we'll re-add canonical versions
# Strip any legacy ``tools`` column previously emitted by older
# writers — the schema no longer uses it (constant lives in
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
if name == "tools":
continue
cols.append(table.column(name))
names.append(name)
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
# exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
return pa.Table.from_arrays(cols, names=names)
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
"""Build a canonical speech tool-call atom for the events column."""
return {
"role": "assistant",
"content": None,
"style": None,
"timestamp": float(timestamp),
"camera": None,
"tool_calls": [
{
"type": "function",
"function": {
"name": "say",
"arguments": {"text": text},
},
}
],
}
+2 -3
View File
@@ -105,9 +105,8 @@ def raw_observation_to_observation(
def prepare_image(image: torch.Tensor) -> torch.Tensor:
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
if image.dtype == torch.uint8:
image = image.type(torch.float32) / 255
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
image = image.type(torch.float32) / 255
image = image.contiguous()
return image
+2 -5
View File
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame (blocking call)
1. Reads a color frame
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
@@ -445,9 +445,8 @@ class OpenCVCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame)
@@ -485,8 +484,6 @@ class OpenCVCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
)
if len(found_devices) > 1:
serial_numbers = [dev["id"] for dev in found_devices]
serial_numbers = [dev["serial_number"] for dev in found_devices]
raise ValueError(
f"Multiple RealSense cameras found with name '{name}'. "
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
)
serial_number = str(found_devices[0]["id"])
serial_number = str(found_devices[0]["serial_number"])
return serial_number
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
from the camera hardware via the RealSense pipeline.
Returns:
np.ndarray: The depth map as a NumPy array (height, width, 1)
of type `np.uint16` (raw depth values in millimeters).
np.ndarray: The depth map as a NumPy array (height, width)
of type `np.uint16` (raw depth values in millimeters) and rotation.
Raises:
DeviceNotConnectedError: If the camera is not connected.
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color/depth frame (blocking call with 10s timeout)
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
1. Reads a color frame with 500ms timeout
2. Stores result in latest_frame and updates timestamp (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
@@ -474,9 +474,8 @@ class RealSenseCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame()
@@ -487,8 +486,6 @@ class RealSenseCamera(Camera):
depth_frame_raw = frame.get_depth_frame()
depth_frame = np.asanyarray(depth_frame_raw.get_data())
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
processed_depth_frame = processed_depth_frame[..., np.newaxis]
capture_time = time.perf_counter()
@@ -525,8 +522,6 @@ class RealSenseCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive(): # pragma: no cover
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None
@@ -537,6 +532,7 @@ class RealSenseCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
@@ -579,6 +575,7 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
@@ -614,73 +611,6 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[np.uint16]:
"""Read the latest depth frame asynchronously, in millimeters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``, where each
pixel is the distance from the sensor in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
the background read thread is not running.
TimeoutError: If no frame becomes available within ``timeout_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
self.new_frame_event.clear()
if depth_frame is None:
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
return depth_frame
@check_if_not_connected
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent depth frame in millimeters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.uint16`` of shape ``(H, W, 1)``, where each pixel is the
distance from the sensor in millimeters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
no depth frame has been captured yet.
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
"""
if not self.use_depth:
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
timestamp = self.latest_timestamp
if depth_frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any depth frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return depth_frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
+1 -4
View File
@@ -249,9 +249,8 @@ class ZMQCamera(Camera):
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.")
stop_event = self.stop_event
failure_count = 0
while not stop_event.is_set():
while not self.stop_event.is_set():
try:
frame = self._read_from_hardware()
capture_time = time.perf_counter()
@@ -293,8 +292,6 @@ class ZMQCamera(Camera):
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
if self.thread.is_alive():
logger.warning(f"{self} read thread did not terminate within timeout.")
self.thread = None
self.stop_event = None
-70
View File
@@ -18,7 +18,6 @@ from __future__ import annotations
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
@@ -244,72 +243,3 @@ def sanity_check_dataset_robot_compatibility(
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)
########################################################################################
# Teleoperator smooth handover helpers
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
# being a root module.
########################################################################################
def teleop_supports_feedback(teleop) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback (i.e. have non-empty
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
``target_pos`` is expected to be in the teleop's action/feedback key space.
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
the robot action key space directly.
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
follower robot does not receive new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def follower_smooth_move_to(
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm to
the follower, the follower is brought to the teleop's current pose so the
robot meets the operator's hand rather than jumping to it on the first frame.
Both ``current`` and ``target`` must be in the robot action key space
(i.e. the output of ``robot_action_processor``).
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
+4 -37
View File
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -96,8 +75,6 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -123,10 +100,6 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -139,9 +112,7 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
@@ -149,8 +120,6 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -162,12 +131,10 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_training_step(train_step, save_dir)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
-7
View File
@@ -35,11 +35,8 @@ from .types import (
from .video import (
VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
)
__all__ = [
@@ -60,12 +57,8 @@ __all__ = [
"WandBConfig",
"load_recipe",
"VideoEncoderConfig",
"DepthEncoderConfig",
# Defaults
"camera_encoder_defaults",
"depth_encoder_defaults",
# Factories
"encoder_config_from_video_info",
# Constants
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",
+3 -5
View File
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
from .video import VideoEncoderConfig, camera_encoder_defaults
@dataclass
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
private: bool | None = None
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
@@ -60,8 +60,6 @@ class DatasetRecordConfig:
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
+1 -6
View File
@@ -35,17 +35,12 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
# Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
# Has no effect on datasets without depth cameras.
depth_output_unit: str = "mm"
streaming: bool = False
def __post_init__(self) -> None:
if self.depth_output_unit not in ("m", "mm"):
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
if self.episodes is not None:
if any(ep < 0 for ep in self.episodes):
raise ValueError(
-6
View File
@@ -177,12 +177,6 @@ class TrainPipelineConfig(HubMixin):
)
active_cfg = self.trainable_config
if self.rename_map and active_cfg.pretrained_path is None:
raise ValueError(
"`rename_map` requires a pretrained policy checkpoint. "
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{active_cfg.type}"
+8 -112
View File
@@ -20,7 +20,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar, Self
from typing import Any
from lerobot.utils.import_utils import require_package
@@ -36,12 +36,11 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
)
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
LIBSVTAV1_DEFAULT_PRESET: int = 12
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
@@ -53,19 +52,6 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
)
# Default depth quantization and encoding parameters.
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@dataclass
class VideoEncoderConfig:
@@ -100,10 +86,6 @@ class VideoEncoderConfig:
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
# Source-data channel count this encoder is expected to handle (3 for RGB,
# 1 for depth, etc.)
_DEFAULT_CHANNELS: ClassVar[int] = 3
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
@@ -112,9 +94,9 @@ class VideoEncoderConfig:
self.validate()
@classmethod
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
"""Parse the ``video.*`` keys of a feature ``info`` block into
constructor kwargs.
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
"""
video_info = video_info or {}
kwargs: dict[str, Any] = {}
@@ -133,15 +115,7 @@ class VideoEncoderConfig:
continue
kwargs[field_name] = value
return kwargs
@classmethod
def from_video_info(cls, video_info: dict | None) -> Self:
"""Reconstruct an encoder config from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
"""
return cls(**cls._kwargs_from_video_info(video_info))
return cls(**kwargs)
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Return the subset of available encoders based on the specified video backend.
@@ -164,9 +138,7 @@ class VideoEncoderConfig:
require_package("av", extra="dataset")
from lerobot.datasets import check_video_encoder_parameters_pyav
check_video_encoder_parameters_pyav(
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
)
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
def resolve_vcodec(self) -> None:
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
@@ -246,10 +218,6 @@ class VideoEncoderConfig:
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "ffv1":
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
# are not meaningful.
set_if("threads", encoder_threads)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
@@ -265,75 +233,3 @@ class VideoEncoderConfig:
def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig()
@dataclass
class DepthEncoderConfig(VideoEncoderConfig):
"""Encoder configuration for depth-map streams.
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``) and adds the four parameters of the depth
quantizer.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"gray12le"``.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
by quantum ``0``.
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
shift: Pre-log offset for numerical stability near zero.
use_log: ``True`` for logarithmic quantization (default; matches
sensor error profile), ``False`` for linear.
"""
vcodec: str = "hevc"
pix_fmt: str = "gray12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
_DEFAULT_CHANNELS: ClassVar[int] = 1
@classmethod
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
``shift`` / ``use_log``) on top of the base parser. Missing keys
fall back to the class defaults.
"""
kwargs = super()._kwargs_from_video_info(video_info)
video_info = video_info or {}
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
value = video_info.get(f"video.{name}")
if value is not None:
kwargs[name] = value
return kwargs
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
"""Build the appropriate encoder config from a feature's ``info`` block.
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
as a depth map and to :class:`VideoEncoderConfig`
otherwise.
Args:
video_info: A feature's ``info`` dict as persisted in ``info.json``,
or ``None`` (treated as an empty dict).
Returns:
A :class:`DepthEncoderConfig` for depth features, otherwise a
:class:`VideoEncoderConfig`.
"""
video_info = video_info or {}
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
return cls.from_video_info(video_info)
+1 -2
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -82,7 +82,6 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+6 -21
View File
@@ -286,8 +286,6 @@ def aggregate_datasets(
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
chunk_size: int | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -305,8 +303,6 @@ def aggregate_datasets(
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
logging.info("Start aggregate_datasets")
@@ -355,12 +351,8 @@ def aggregate_datasets(
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
)
data_idx = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
)
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
@@ -375,9 +367,7 @@ def aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
):
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
@@ -389,7 +379,6 @@ def aggregate_videos(
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -450,7 +439,7 @@ def aggregate_videos(
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
@@ -488,7 +477,7 @@ def aggregate_videos(
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
)
# Record the mapping from source to actual destination
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if not concatenate or dst_size + src_size >= max_mb:
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
+5 -15
View File
@@ -59,8 +59,6 @@ class RunningQuantileStats:
batch: An array where all dimensions except the last are batch dimensions.
"""
batch = batch.reshape(-1, batch.shape[-1])
# Promote integer and low-precision inputs before computing squared statistics.
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
num_elements, vector_length = batch.shape
if self._count == 0:
@@ -506,10 +504,8 @@ def compute_episode_stats(
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
Note:
For 'image'/'video' features, stats are computed per channel and kept with a
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
this rescaling and remain in their stored units.
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
per-channel values when dtype is 'image' or 'video'.
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
@@ -533,12 +529,8 @@ def compute_episode_stats(
)
if features[key]["dtype"] in ["image", "video"]:
normalization_factor = (
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
)
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
@@ -558,10 +550,8 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
raise ValueError(
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
)
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
+7 -48
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable, Iterable
from collections.abc import Callable
from pathlib import Path
import numpy as np
@@ -337,25 +337,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos or images.
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
"""
def _is_depth(ft: dict) -> bool:
info = ft.get("info") or {}
video_info = ft.get("video_info") or {}
return (
info.get("is_depth_map", False)
or info.get("video.is_depth_map", False)
or video_info.get("video.is_depth_map", False)
)
return [key for key, ft in self.features.items() if _is_depth(ft)]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
@@ -599,51 +580,29 @@ class LeRobotDatasetMetadata:
def update_video_info(
self,
video_key: str | None = None,
video_encoder: VideoEncoderConfig | None = None,
preserve_keys: Iterable[str] | None = None,
camera_encoder: VideoEncoderConfig | None = None,
) -> None:
"""Populate or refresh per-feature video info in ``info.json``.
"""Populate per-feature video info in ``info.json``.
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
Two modes, selected by ``preserve_keys``:
- **Populate** (``None``, default): write info for video keys that lack it,
skip the rest. Used when first encoding a dataset.
- **Refresh** (any iterable): re-probe and overwrite existing info, keeping
the listed keys. Used after re-encoding to preserve data-intrinsic entries
(``is_depth_map``, depth quantization params) while codec params change.
Args:
video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset.
video_encoder: Encoder configuration used to produce the
camera_encoder: Encoder configuration used to produce the
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
preserve_keys: ``None`` (default) for populate-once mode. An iterable
(possibly empty) switches to refresh mode, keeping these keys'
existing values while recomputing the rest.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
refresh = preserve_keys is not None
preserve_set = set(preserve_keys or ())
for key in video_keys:
existing = self.features[key].get("info") or {}
# ``is_depth_map`` is written at feature creation and does not count as real video info here.
already_populated = bool(set(existing.keys()) - {"is_depth_map"})
# Populate-once: never clobber info that has already been written unless a refresh is requested.
if already_populated and not refresh:
continue
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
new_info = get_video_info(video_path, video_encoder=video_encoder)
# Drop preserved keys so the existing values win on merge.
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
self.info.features[key]["info"] = {**existing, **new_info}
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
def update_chunk_settings(
self,
-26
View File
@@ -22,10 +22,7 @@ from pathlib import Path
import datasets
import torch
from lerobot.configs.video import DepthEncoderConfig
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -54,7 +51,6 @@ class DatasetReader:
delta_timestamps: dict[str, list[float]] | None,
image_transforms: Callable | None,
return_uint8: bool = False,
depth_output_unit: str = "mm",
):
"""Initialize the reader with metadata, filtering, and transform config.
@@ -72,10 +68,6 @@ class DatasetReader:
relative timestamp offsets for temporal context windows.
image_transforms: Optional torchvision v2 transform applied to
visual features.
return_uint8: If True, return RGB video frames as raw uint8 tensors
instead of normalized float32.
depth_output_unit: Physical unit depth maps are dequantized to
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
"""
self._meta = meta
self.root = root
@@ -84,7 +76,6 @@ class DatasetReader:
self._video_backend = video_backend
self._image_transforms = image_transforms
self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self.hf_dataset: datasets.Dataset | None = None
self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -95,12 +86,6 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
@@ -262,18 +247,7 @@ class DatasetReader:
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
depth_encoder = self._depth_encoder_configs[vid_key]
frames = dequantize_depth(
frames,
depth_min=depth_encoder.depth_min,
depth_max=depth_encoder.depth_max,
shift=depth_encoder.shift,
use_log=depth_encoder.use_log,
output_unit=self._depth_output_unit,
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())
+64 -109
View File
@@ -36,14 +36,7 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
)
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict
@@ -54,7 +47,6 @@ from .compute_stats import (
compute_relative_action_stats,
)
from .dataset_metadata import LeRobotDatasetMetadata
from .image_writer import write_image
from .io_utils import (
get_parquet_file_size_in_mb,
load_episodes,
@@ -69,13 +61,12 @@ from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEPTH_FILE_PATTERN,
IMAGE_FILE_PATTERN,
VIDEO_DIR,
update_chunk_file_indices,
)
from .video_utils import (
encode_video_frames,
get_video_info,
reencode_video,
)
@@ -270,8 +261,6 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -281,8 +270,6 @@ def merge_datasets(
datasets: List of LeRobotDatasets to merge.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -297,8 +284,6 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
concatenate_videos=concatenate_videos,
concatenate_data=concatenate_data,
)
merged_dataset = LeRobotDataset(
@@ -609,7 +594,7 @@ def _keep_episodes_from_video_with_av(
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
fps: float,
video_encoder: VideoEncoderConfig,
camera_encoder: VideoEncoderConfig,
) -> None:
"""Keep only specified episodes from a video file using PyAV.
@@ -623,7 +608,7 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
video_encoder: Video encoder settings used to re-encode the kept frames.
camera_encoder: Video encoder settings used to re-encode the kept frames.
"""
from fractions import Fraction
@@ -648,13 +633,13 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000)
codec_options = video_encoder.get_codec_options(as_strings=True)
v_out = out.add_stream(video_encoder.vcodec, rate=fps_fraction, options=codec_options)
codec_options = camera_encoder.get_codec_options(as_strings=True)
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height
v_out.pix_fmt = video_encoder.pix_fmt
v_out.pix_fmt = camera_encoder.pix_fmt
# Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps))
@@ -741,7 +726,7 @@ def _copy_and_reindex_videos(
for video_key in src_dataset.meta.video_keys:
logging.info(f"Processing videos for {video_key}")
video_encoder = encoder_config_from_video_info(
camera_encoder = VideoEncoderConfig.from_video_info(
src_dataset.meta.info.features.get(video_key, {}).get("info")
)
@@ -825,7 +810,7 @@ def _copy_and_reindex_videos(
dst_video_path,
episodes_to_keep_ranges,
src_dataset.meta.fps,
video_encoder,
camera_encoder,
)
cumulative_ts = 0.0
@@ -1159,15 +1144,15 @@ def _save_episode_images_for_video(
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -1199,14 +1184,13 @@ def _save_batch_episodes_images(
hf_dataset = dataset.hf_dataset.with_format(None)
imgs_dataset = hf_dataset.select_columns(img_key)
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image with global frame index
# Defined once outside the loop to avoid repeated closure creation
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
i, item = i_item_tuple
write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
img = item[img_key_param]
# Use global frame index for naming
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
return i
episode_durations = []
@@ -1297,7 +1281,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int],
temp_dir: Path,
fps: int,
video_encoder: VideoEncoderConfig,
camera_encoder: VideoEncoderConfig,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
@@ -1311,7 +1295,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
video_encoder: Video encoder settings used for calibration encoding.
camera_encoder: Video encoder settings used for calibration encoding.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
@@ -1336,11 +1320,10 @@ def _estimate_frame_size_via_calibration(
hf_dataset = dataset.hf_dataset.with_format(None)
sample_indices = range(from_idx, from_idx + num_frames)
# Save calibration frames using the suffix/format the encoder expects.
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Save calibration frames
for i, idx in enumerate(sample_indices):
write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
img = hf_dataset[idx][img_key]
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
# Encode calibration video
calibration_video_path = calibration_dir / "calibration.mp4"
@@ -1348,7 +1331,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
video_encoder=video_encoder,
camera_encoder=camera_encoder,
overwrite=True,
)
@@ -1621,7 +1604,6 @@ def recompute_stats(
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
# TODO: enable image and video stats re-computation
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
@@ -1668,7 +1650,6 @@ def convert_image_to_video_dataset(
output_dir: Path | None = None,
repo_id: str | None = None,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
@@ -1680,32 +1661,21 @@ def convert_image_to_video_dataset(
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images.
output_dir: Root directory where the converted dataset will be stored. When
``None``, defaults to ``$HF_LEROBOT_HOME/repo_id``. Equivalent to
``new_root`` in ``EditDatasetConfig``.
repo_id: Converted dataset identifier. Equivalent to ``new_repo_id`` in
``EditDatasetConfig``.
camera_encoder: Video encoder settings applied to RGB cameras. When ``None``,
:func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings applied to depth-map cameras, including
the quantization parameters persisted to the dataset metadata. When
``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
episode_indices: Episode indices to convert. When ``None``, all episodes are
converted.
num_workers: Number of threads for parallel processing.
max_episodes_per_batch: Maximum episodes per video batch, to bound memory use.
``None`` means no limit.
max_frames_per_batch: Maximum frames per video batch, to bound memory use.
``None`` means no limit.
dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
camera_encoder: Video encoder settings
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
Returns:
A new :class:`LeRobotDataset` with images encoded as videos.
New LeRobotDataset with images encoded as videos
"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
if depth_encoder is None:
depth_encoder = depth_encoder_defaults()
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
@@ -1730,7 +1700,10 @@ def convert_image_to_video_dataset(
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
logging.info(
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
)
# Create new features dict, converting image features to video features
new_features = {}
@@ -1792,8 +1765,6 @@ def convert_image_to_video_dataset(
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
for img_key in tqdm(img_keys, desc="Processing cameras"):
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
# Estimate size per frame by encoding a small calibration sample
# This provides accurate compression ratio for the specific codec parameters
size_per_frame_mb = _estimate_frame_size_via_calibration(
@@ -1802,7 +1773,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
video_encoder=target_encoder,
camera_encoder=camera_encoder,
)
logging.info(f"Processing camera: {img_key}")
@@ -1844,7 +1815,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
video_encoder=target_encoder,
camera_encoder=camera_encoder,
overwrite=True,
)
@@ -1883,11 +1854,16 @@ def convert_image_to_video_dataset(
new_meta.info.total_tasks = dataset.meta.total_tasks
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos). They are registered as
# video features above, so update_video_info populates their (still-empty) info.
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
new_meta.update_video_info(video_key=img_key, video_encoder=target_encoder)
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, camera_encoder=camera_encoder
)
write_info(new_meta.info, new_meta.root)
@@ -1914,11 +1890,11 @@ def convert_image_to_video_dataset(
def _reencode_video_worker(args: tuple) -> Path:
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
video_path, video_encoder, encoder_threads = args
video_path, camera_encoder, encoder_threads = args
reencode_video(
input_video_path=video_path,
output_video_path=video_path,
video_encoder=video_encoder,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
@@ -1927,8 +1903,7 @@ def _reencode_video_worker(args: tuple) -> Path:
def reencode_dataset(
dataset: LeRobotDataset,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig,
encoder_threads: int | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
@@ -1939,11 +1914,8 @@ def reencode_dataset(
Args:
dataset: An existing :class:`LeRobotDataset` whose videos will be
re-encoded.
camera_encoder: Target encoder configuration applied to every RGB video
file. If ``None``, re-encoding is skipped for RGB videos.
depth_encoder: Target encoder configuration applied to every depth video
file. If ``None``, re-encoding is skipped for depth videos.
Quantization parameters will not override the ones in the current dataset.
camera_encoder: Target encoder configuration applied to every video
file.
encoder_threads: Per-encoder thread count forwarded to
:func:`reencode_video`. ``None`` lets the codec decide.
num_workers: Number of parallel processes. ``None`` or ``0`` means
@@ -1955,35 +1927,23 @@ def reencode_dataset(
on disk.
"""
meta = dataset.meta
video_keys_encoders_dict = {}
video_keys_paths_dict = {}
if camera_encoder is None and depth_encoder is None:
raise ValueError("Either camera_encoder or depth_encoder must be provided")
video_paths_list = []
# Only re-encode if the videos are not already encoded with the given video encoding parameters
for video_key in meta.video_keys:
current_info = meta.info.features[video_key].get("info", {})
current_encoder = encoder_config_from_video_info(current_info)
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
if target_encoder is None:
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
elif current_encoder != target_encoder:
video_keys_paths_dict[video_key] = list((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
video_keys_encoders_dict[video_key] = target_encoder
current_encoder = VideoEncoderConfig.from_video_info(current_info)
if current_encoder != camera_encoder:
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
else:
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
if len(video_keys_paths_dict) == 0:
if len(video_paths_list) == 0:
logging.warning("Dataset has no videos to re-encode.")
return dataset
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
worker_args = [
(path, encoder, encoder_threads)
for video_key, encoder in video_keys_encoders_dict.items()
for path in video_keys_paths_dict[video_key]
]
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
@@ -1997,15 +1957,10 @@ def reencode_dataset(
for args in tqdm(worker_args, desc="Re-encoding videos"):
_reencode_video_worker(args)
# Refresh video info in metadata for every re-encoded key. Re-encoding only
# changes codec/container params, so for depth videos we preserve ``is_depth_map``
# and the depth quantization params (``video.depth_min`` / ``video.depth_max`` /
# ...), which describe the data rather than the codec and must survive a transcode.
# RGB videos pass an empty set: still a refresh, but nothing to preserve.
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
for video_key, encoder in video_keys_encoders_dict.items():
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else set()
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
# Refresh video info in metadata for every video key.
for vid_key in meta.video_keys:
video_path = meta.root / meta.get_video_file_path(0, vid_key)
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.")
+12 -41
View File
@@ -31,12 +31,7 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
from .compute_stats import compute_episode_stats
from .dataset_metadata import LeRobotDatasetMetadata
@@ -53,7 +48,6 @@ from .io_utils import (
write_info,
)
from .utils import (
DEFAULT_DEPTH_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
@@ -73,22 +67,17 @@ def _encode_video_worker(
episode_index: int,
root: Path,
fps: int,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
path_template = (
DEFAULT_DEPTH_PATH
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
else DEFAULT_IMAGE_PATH
)
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir,
temp_path,
fps,
video_encoder=video_encoder,
camera_encoder=camera_encoder,
encoder_threads=encoder_threads,
overwrite=True,
)
@@ -108,7 +97,6 @@ class DatasetWriter:
meta: LeRobotDatasetMetadata,
root: Path,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
@@ -120,11 +108,8 @@ class DatasetWriter:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
camera_encoder: Video encoder settings applied to RGB cameras. When
``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings applied to depth cameras, including
the quantization parameters. When ``None``,
:func:`~lerobot.configs.video.depth_encoder_defaults` is used.
camera_encoder: Video encoder settings applied to all cameras.
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before
@@ -136,7 +121,6 @@ class DatasetWriter:
self._meta = meta
self._root = root
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
@@ -161,8 +145,7 @@ class DatasetWriter:
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
@@ -212,7 +195,6 @@ class DatasetWriter:
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self._meta.video_keys),
depth_video_keys=list(self._meta.depth_keys),
temp_dir=self._root,
)
@@ -300,13 +282,10 @@ class DatasetWriter:
if use_streaming:
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self._meta.video_keys:
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
ep_stats[video_key] = {
k: v
if k == "count"
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
@@ -321,9 +300,7 @@ class DatasetWriter:
episode_index,
self._root,
self._meta.fps,
self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
self._camera_encoder,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
@@ -534,12 +511,7 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(
video_key,
video_encoder=self._depth_encoder
if video_key in self._meta.depth_keys
else self._camera_encoder,
)
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
write_info(self._meta.info, self._meta.root)
metadata = {
@@ -606,14 +578,13 @@ class DatasetWriter:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
is_depth = video_key in self._meta.depth_keys
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key,
episode_index,
self._root,
self._meta.fps,
self._depth_encoder if is_depth else self._camera_encoder,
self._camera_encoder,
self._encoder_threads,
)
-256
View File
@@ -1,256 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
"""
import math
from typing import Literal
import av
import numpy as np
import torch
from numpy.typing import NDArray
from lerobot.configs.video import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_PIX_FMT,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
DEPTH_QMAX,
)
from .pyav_utils import write_u16_plane
_MM_PER_METRE = 1000.0
_UINT16_MAX = 65535
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
raise ValueError(
f"depth_min + shift must be positive for logarithmic quantization, "
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
)
def _depth_input_to_float32_and_unit(
depth: NDArray[np.integer] | NDArray[np.floating],
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
)
return depth.astype(np.float32, order="K"), resolved_unit
def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
video_backend: str | None = "pyav",
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16] | av.VideoFrame:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
Logarithmic quantization is the default because it allocates more quanta
to near-range depth, which matches the (1/depth) error profile of typical
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
**Input units**:
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
- ``input_unit="mm"``: interpret input values as millimetres.
- ``input_unit="m"``: interpret input values as metres.
Quantization math runs in the **resolved input unit**.
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
Args:
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
depth_min: Depth (metres) at quantum ``0``.
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space.
video_backend: Video backend to use for encoding. Defaults to "pyav".
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns:
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
``[0, DEPTH_QMAX]``.
Raises:
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
depth = depth.squeeze()
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
# Convert depth_min, depth_max, and shift to the resolved input unit.
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
# Normalization and quantization is performed in the resolved input unit.
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u))
log_max = math.log(float(depth_max_u + shift_u))
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
if video_backend == "pyav":
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
write_u16_plane(frame.planes[0], quantized)
return frame
else:
return quantized
def dequantize_depth(
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
output_unit: Literal["m", "mm"] = "mm",
output_tensor: bool = True,
output_channel_last: bool = False,
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
"""Inverse of :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
Accepted input layouts :
- ``(H, W, 1)`` or ``(H, W)`` single frame with channel-last.
- ``(..., 1, H, W)`` batched frames with channel-first.
- ``(..., H, W, 1)`` batched frames with channel-last.
Output layout is determined by ``output_channel_last``.
Args:
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if use_log:
_validate_log_quant_params(depth_min, shift)
if isinstance(quantized, av.VideoFrame):
quantized = quantized.to_ndarray(format=pix_fmt)
# Compute the scale and offset first.
depth_min_m = float(depth_min)
depth_max_m = float(depth_max)
shift_m = float(shift)
if use_log:
log_min = math.log(depth_min_m + shift_m)
log_max = math.log(depth_max_m + shift_m)
scale = (log_max - log_min) / DEPTH_QMAX
offset = log_min
else:
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
offset = depth_min_m
# ── Torch path: stay on the input device, single fp32 allocation. ────────
if isinstance(quantized, torch.Tensor):
if quantized.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
# Single allocation we own; everything else is in-place.
buf = quantized.to(dtype=torch.float32, copy=True)
buf.mul_(scale).add_(offset)
if use_log:
buf.exp_().sub_(shift_m)
buf.clamp_(depth_min_m, depth_max_m)
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
if output_unit == "m":
return buf if output_tensor else buf.cpu().numpy()
# mm path: round + clamp in float32, skipping the uint16 round-trip
# when returning a tensor (torch.uint16 is poorly supported).
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
if output_tensor:
return buf
return buf.cpu().numpy().astype(np.uint16, copy=False)
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
arr = np.asarray(quantized)
if arr.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
buf = np.empty(arr.shape, dtype=np.float32)
np.multiply(arr, scale, out=buf)
np.add(buf, offset, out=buf)
if use_log:
np.exp(buf, out=buf)
np.subtract(buf, shift_m, out=buf)
np.clip(buf, depth_min_m, depth_max_m, out=buf)
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
if output_unit == "m":
return torch.from_numpy(buf) if output_tensor else buf
np.multiply(buf, _MM_PER_METRE, out=buf)
np.rint(buf, out=buf)
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
if output_tensor:
# torch.uint16 support is very limited; return float32 millimetres.
return torch.from_numpy(buf)
return buf.astype(np.uint16, copy=False)
-1
View File
@@ -96,7 +96,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
depth_output_unit=cfg.dataset.depth_output_unit,
tolerance_s=cfg.tolerance_s,
)
else:
+1 -1
View File
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
+5 -51
View File
@@ -42,41 +42,10 @@ def safe_stop_image_writer(func):
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
# TODO(CarolinePascal): 4 dimensions RGB-D images
if image_array.ndim not in (2, 3):
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
@@ -102,28 +71,13 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
"""
Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for
the save operation. The output format is inferred from the *fpath*
extension: ``.png`` PNG with ``compress_level``, ``.tiff`` / ``.tif``
lossless raw depth maps (TIFF).
the save operation.
Args:
image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -147,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
img.save(fpath, compress_level=compress_level)
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)
+5 -28
View File
@@ -24,7 +24,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
from lerobot.configs import VideoEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
@@ -58,10 +58,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
download_videos: bool = True,
video_backend: str | None = None,
return_uint8: bool = False,
depth_output_unit: str = "mm",
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -188,9 +186,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
is used by the writer.
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults`
is used by the writer.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
@@ -213,7 +208,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8
self._depth_output_unit = depth_output_unit
self._batch_encoding_size = batch_encoding_size
self._encoder_threads = encoder_threads
@@ -254,7 +248,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
)
# Load actual data
@@ -280,7 +273,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = self._build_streaming_encoder(
self.meta.fps,
camera_encoder,
depth_encoder,
encoder_queue_maxsize,
encoder_threads,
)
@@ -288,7 +280,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
meta=self.meta,
root=self.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -324,7 +315,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=self.delta_timestamps,
image_transforms=self.image_transforms,
return_uint8=self._return_uint8,
depth_output_unit=self._depth_output_unit,
)
return self.reader
@@ -332,14 +322,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _build_streaming_encoder(
fps: int,
camera_encoder: VideoEncoderConfig | None,
depth_encoder: DepthEncoderConfig | None,
encoder_queue_maxsize: int,
encoder_threads: int | None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
@@ -536,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool | None = None,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
**card_kwargs,
@@ -555,8 +543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
tag_version: If ``True``, create a Git tag for the current codebase
version.
push_videos: If ``False``, skip uploading the ``videos/`` directory.
private: If ``True``, create a private repository. If ``None``
(default), defer to the org default on the Hub (only affects orgs).
private: If ``True``, create a private repository.
allow_patterns: Glob pattern(s) restricting which files to upload.
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
of ``upload_folder`` for very large datasets.
@@ -658,7 +645,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
@@ -691,8 +677,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos. ``1`` means encode immediately.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer
@@ -727,7 +711,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size
obj._encoder_threads = encoder_threads
@@ -737,13 +720,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -767,7 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
@@ -797,8 +778,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch-encoding videos.
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
image_writer_processes: Subprocesses for async image writing.
@@ -826,7 +805,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False
obj._depth_output_unit = "mm"
obj._batch_encoding_size = batch_encoding_size
if obj._requested_root is not None:
@@ -846,13 +824,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
camera_encoder=camera_encoder,
depth_encoder=depth_encoder,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
+3 -5
View File
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `exclude_images` and `patterns`, and finally
(image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
use_videos: If False, image features are excluded.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and exclude_images:
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
+2 -37
View File
@@ -24,7 +24,6 @@ import logging
from typing import Any
import av
import numpy as np
logger = logging.getLogger(__name__)
@@ -32,22 +31,6 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
@functools.cache
def get_pix_fmt_channels(pix_fmt: str) -> int:
"""Return the number of components (channels) for *pix_fmt*."""
return len(av.VideoFormat(pix_fmt).components)
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -109,7 +92,7 @@ def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Opti
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e
elif isinstance(value, (float, int)):
num_val = float(value)
num_val = value
else:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
@@ -159,16 +142,6 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
)
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
"""Ensure *pix_fmt* can carry at least *channels* components."""
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
if pix_fmt_channels < channels:
raise ValueError(
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
f"but the source data has {channels} channel(s)."
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
@@ -183,18 +156,12 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
_check_option_value(vcodec, key, value, supported_options[key])
def check_video_encoder_parameters_pyav(
vcodec: str,
pix_fmt: str,
codec_options: dict[str, Any],
channels: int | None = None,
) -> None:
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
@@ -204,6 +171,4 @@ def check_video_encoder_parameters_pyav(
if not options:
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
_check_pixel_format(vcodec, pix_fmt)
if channels is not None:
_check_pix_fmt_channels(pix_fmt, channels)
_check_codec_options(vcodec, codec_options)
+32 -122
View File
@@ -14,36 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
"""Sampler over episode frames that stores only per-episode boundaries.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
seed: int = 0,
):
"""
"""Sampler that optionally incorporates episode boundary information.
Args:
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.indices = indices
self.shuffle = shuffle
self.seed = seed
self._epoch = 0
self._start_index = 0
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for k in range(start, self._num_frames):
yield self._frame_index(k)
for i in self.indices:
yield i
def __len__(self) -> int:
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
return len(self.indices)
+1 -4
View File
@@ -87,14 +87,11 @@ DATA_DIR = "data"
VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
+72 -172
View File
@@ -39,16 +39,11 @@ from datasets.features.features import register_feature
from PIL import Image
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
@@ -58,7 +53,6 @@ def decode_video_frames(
tolerance_s: float,
backend: str | None = None,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""
Decodes video frames using the specified backend.
@@ -70,35 +64,23 @@ def decode_video_frames(
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
Returns:
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
Currently supports torchcodec on cpu and pyav.
"""
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
)
if backend is None:
backend = get_safe_default_video_backend()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "pyav":
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -109,7 +91,6 @@ def decode_video_frames_pyav(
tolerance_s: float,
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
is_depth: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video using PyAV.
@@ -128,9 +109,8 @@ def decode_video_frames_pyav(
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
Otherwise, return float32 in [0, 1] range.
is_depth: Set to True if the video is a depth map (1 channel, uint12).
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
[0, 1] range.
Returns:
torch.Tensor of shape (len(timestamps), C, H, W).
@@ -160,13 +140,9 @@ def decode_video_frames_pyav(
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
if is_depth:
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
else:
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
# Convert to CHW uint8 to match torchcodec's output layout.
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
# Convert to CHW uint8 to match torchcodec's output layout.
arr = frame.to_ndarray(format="rgb24") # H, W, 3
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
@@ -209,7 +185,7 @@ def decode_video_frames_pyav(
f"number of queried timestamps ({len(timestamps)})"
)
if return_uint8 or is_depth:
if return_uint8:
return closest_frames
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
@@ -430,38 +406,17 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
) -> None:
"""Encode a directory of image frames into an MP4 video.
When ``video_encoder`` is a :class:`~lerobot.configs.video.DepthEncoderConfig`,
frames are read from ``.tiff`` files and quantized to 12-bit depth codes using the
encoder's ``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``; otherwise ``.png``
RGB frames are encoded directly.
Args:
imgs_dir: Directory containing the frames to encode, named ``frame-000000``
onwards (``.png`` for RGB, ``.tiff`` for depth).
video_path: Output path for the encoded ``.mp4`` file.
fps: Frame rate of the output video.
video_encoder: Encoder settings (codec, pixel format, quality, ...). When
``None``, :func:`camera_encoder_defaults` is used. Pass a
:class:`~lerobot.configs.video.DepthEncoderConfig` to encode depth frames.
encoder_threads: Per-encoder thread count forwarded to the codec. ``None``
lets the codec decide.
log_level: libav log level to set while encoding, or ``None`` to leave the
current logging configuration unchanged.
overwrite: When ``False`` and ``video_path`` already exists, skip encoding and
log a warning. When ``True``, re-encode and replace the existing file.
"""
if video_encoder is None:
video_encoder = camera_encoder_defaults()
vcodec = video_encoder.vcodec
pix_fmt = video_encoder.pix_fmt
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
if camera_encoder is None:
camera_encoder = camera_encoder_defaults()
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -473,19 +428,17 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True)
# Get input frames
is_depth = isinstance(video_encoder, DepthEncoderConfig)
suffix = ".png" if not is_depth else ".tiff"
template = "frame-" + ("[0-9]" * 6) + suffix
template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
if len(input_list) == 0:
raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
# Set logging level
if log_level is not None:
@@ -502,19 +455,8 @@ def encode_video_frames(
# Loop through input frames and encode them
for input_data in input_list:
with Image.open(input_data) as input_image:
if is_depth:
input_frame = quantize_depth(
np.array(input_image),
depth_min=video_encoder.depth_min,
depth_max=video_encoder.depth_max,
shift=video_encoder.shift,
use_log=video_encoder.use_log,
pix_fmt=video_encoder.pix_fmt,
video_backend="pyav",
)
else:
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
@@ -535,32 +477,23 @@ def encode_video_frames(
def reencode_video(
input_video_path: Path | str,
output_video_path: Path | str,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
start_time_s: float | None = None,
end_time_s: float | None = None,
) -> None:
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
"""Re-encode a video file using the given encoder configuration.
Args:
input_video_path: Existing video file to read.
output_video_path: Path for the re-encoded file.
video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
start_time_s: When set, trim the output to start at this timestamp (seconds).
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
"""
video_encoder = video_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
camera_encoder = camera_encoder or camera_encoder_defaults()
output_video_path = Path(output_video_path)
@@ -570,9 +503,9 @@ def reencode_video(
output_video_path.parent.mkdir(parents=True, exist_ok=True)
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
vcodec = video_encoder.vcodec
pix_fmt = video_encoder.pix_fmt
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
vcodec = camera_encoder.vcodec
pix_fmt = camera_encoder.pix_fmt
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
@@ -593,10 +526,6 @@ def reencode_video(
width = int(in_stream.width)
height = int(in_stream.height)
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
if start_time_s is not None:
src.seek(int(start_time_s * av.time_base), backward=True)
with av.open(
tmp_output_video_path,
mode="w",
@@ -610,14 +539,7 @@ def reencode_video(
out_stream.height = height
for frame in src.decode(in_stream):
frame_time_s = frame.time
if start_time_s is not None and frame_time_s < start_time_s:
continue
if end_time_s is not None and frame_time_s >= end_time_s:
break
frame = frame.reformat(width=width, height=height, format=pix_fmt)
if start_time_s is not None:
frame.pts = None # reset timestamps so the trimmed output starts at t=0
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
@@ -754,21 +676,22 @@ class _CameraEncoderThread(threading.Thread):
self,
video_path: Path,
fps: int,
video_encoder: VideoEncoderConfig,
vcodec: str,
pix_fmt: str,
codec_options: dict[str, str],
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.video_encoder = video_encoder
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.codec_options = codec_options
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
@@ -793,12 +716,12 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
# Ensure HWC uint8 numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if not self.is_depth and frame_data.dtype != np.uint8:
if frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
@@ -806,29 +729,15 @@ class _CameraEncoderThread(threading.Thread):
height, width = frame_data.shape[:2]
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(
self.video_encoder.vcodec,
self.fps,
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
)
output_stream.pix_fmt = self.video_encoder.pix_fmt
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
if not self.is_depth:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
else:
video_frame = quantize_depth(
frame_data,
depth_min=self.video_encoder.depth_min,
depth_max=self.video_encoder.depth_max,
shift=self.video_encoder.shift,
use_log=self.video_encoder.use_log,
video_backend=self.video_encoder.video_backend,
)
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
@@ -887,26 +796,21 @@ class StreamingVideoEncoder:
self,
fps: int,
camera_encoder: VideoEncoderConfig | None = None,
depth_encoder: DepthEncoderConfig | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
Args:
fps: Frames per second for the output videos.
camera_encoder: Video encoder settings applied to all RGB cameras.
camera_encoder: Video encoder settings applied to all cameras.
When ``None``, :func:`camera_encoder_defaults` is used.
depth_encoder: Video encoder settings applied to all depth cameras,
including the depth quantization parameters. When ``None``,
:func:`depth_encoder_defaults` is used.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
"""
self.fps = fps
self._camera_encoder = camera_encoder or camera_encoder_defaults()
self._depth_encoder = depth_encoder or depth_encoder_defaults()
self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize
@@ -919,25 +823,18 @@ class StreamingVideoEncoder:
self._episode_active = False
self._closed = False
def start_episode(
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
) -> None:
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
"""
if self._episode_active:
self.cancel_episode()
self._dropped_frames.clear()
if depth_video_keys is None:
depth_video_keys = []
for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1)
@@ -946,15 +843,17 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
vcodec = self._camera_encoder.vcodec
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
video_encoder=encoder,
vcodec=vcodec,
pix_fmt=self._camera_encoder.pix_fmt,
codec_options=codec_options,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self._encoder_threads,
)
encoder_thread.start()
@@ -1161,23 +1060,15 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info(
video_path: Path | str,
video_encoder: VideoEncoderConfig | None = None,
camera_encoder: VideoEncoderConfig | None = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
video_encoder: If provided, record the exact encoder settings used to encode this
camera_encoder: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence encoder fields are only written for keys
not already populated from the video file itself. When a
:class:`~lerobot.configs.video.DepthEncoderConfig` is passed, the depth
quantization parameters (``depth_min`` / ``depth_max`` / ``shift`` /
``use_log``) are recorded so frames can be dequantized on read.
Returns:
The ``video.*`` / ``audio.*`` info dict, including ``is_depth_map`` which is
``True`` only when ``video_encoder`` is a
:class:`~lerobot.configs.video.DepthEncoderConfig`.
not already populated from the video file itself.
"""
logging.getLogger("libav").setLevel(av.logging.WARNING)
@@ -1195,10 +1086,13 @@ def get_video_info(
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
video_info["video.channels"] = pixel_channels
# Reset logging level
av.logging.restore_default_callback()
@@ -1207,18 +1101,27 @@ def get_video_info(
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided
if video_encoder is not None:
for field_name, field_value in asdict(video_encoder).items():
if camera_encoder is not None:
for field_name, field_value in asdict(camera_encoder).items():
# vcodec is already populated from the video stream
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.
@@ -1279,13 +1182,10 @@ class VideoEncodingManager:
img_dir = self.dataset.root / "images"
if img_dir.exists():
png_files = list(img_dir.rglob("*.png"))
tiff_files = list(img_dir.rglob("*.tiff"))
if len(png_files) == 0 and len(tiff_files) == 0:
if len(png_files) == 0:
shutil.rmtree(img_dir)
logger.debug("Cleaned up empty images directory")
else:
logger.debug(
f"Images directory is not empty, containing {len(png_files)} PNG and {len(tiff_files)} TIFF files"
)
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
-16
View File
@@ -57,7 +57,6 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -158,10 +157,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -216,8 +211,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -422,7 +415,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -440,14 +432,6 @@ def make_pre_post_processors(
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+5 -85
View File
@@ -29,7 +29,6 @@ from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.__version__ import __version__
from lerobot.configs import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.hub import HubMixin
@@ -39,67 +38,6 @@ from .utils import log_model_loading_keys
T = TypeVar("T", bound="PreTrainedPolicy")
def _build_card_context(
cfg: TrainPipelineConfig | None,
dataset_repo_id: str | None,
input_features: dict | None,
output_features: dict | None,
) -> dict:
"""Collect optional data for the model-card template.
Returns plain values only (no Markdown) the template in
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
each one. Everything is best-effort: anything unavailable is left empty/None and the
template simply skips that section, so this never breaks a Hub push.
"""
context = {
"training": None,
"input_features": input_features or {},
"output_features": output_features or {},
"dataset": None,
"robot_type": None,
"cameras": [],
}
if cfg is not None:
optimizer = getattr(cfg, "optimizer", None)
context["training"] = {
"steps": cfg.steps,
"batch_size": cfg.batch_size,
"seed": cfg.seed,
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
"lr": getattr(optimizer, "lr", None) if optimizer else None,
"lerobot_version": __version__,
}
if dataset_repo_id:
dataset_cfg = getattr(cfg, "dataset", None)
try:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(
dataset_repo_id,
root=getattr(dataset_cfg, "root", None),
revision=getattr(dataset_cfg, "revision", None),
)
context["dataset"] = {
"repo_id": dataset_repo_id,
"episodes": meta.total_episodes,
"frames": meta.total_frames,
"fps": meta.fps,
"tasks": [str(task) for task in meta.tasks.index],
}
context["robot_type"] = meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
logging.warning(
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
f"omitted from the model card. ({e})"
)
return context
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
@@ -290,7 +228,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
)
card.save(str(saved_path / "README.md"))
@@ -308,20 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self,
dataset_repo_id: str,
model_type: str,
license: str | None,
tags: list[str] | None,
cfg: TrainPipelineConfig | None = None,
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard:
base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
"pi0": "lerobot/pi0_base",
"pi05": "lerobot/pi05_base",
"pi0_fast": "lerobot/pi0fast-base",
"xvla": "lerobot/xvla-base",
}
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
card_data = ModelCardData(
license=license or "apache-2.0",
@@ -330,20 +257,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
base_model=base_model_mapping.get(model_type),
base_model=base_model,
)
context = _build_card_context(
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
)
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
context["base_model"] = base_model_mapping.get(model_type)
template_card = (
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
)
card = ModelCard.from_template(card_data, template_str=template_card, **context)
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card
+1 -2
View File
@@ -126,8 +126,7 @@ def prepare_observation_for_inference(
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
if observation[name].dtype == torch.uint8:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
-1
View File
@@ -1 +0,0 @@
../../../../docs/source/policy_vla_jepa_README.md
-23
View File
@@ -1,23 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]
@@ -1,337 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions
@@ -1,154 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -1,629 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model
@@ -1,155 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold -1, <= threshold 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -1,117 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
@@ -1,418 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)
+54 -278
View File
@@ -32,6 +32,7 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
config_filename = kwargs.pop("config_filename", None)
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
step_entry["config"] = processor_step.get_config()
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
pipeline_config["steps"].append(step_entry)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
return pipeline_config
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
config["steps"].append(step_entry)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
pipeline = cls(
return cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config value to validate (may be non-dict)
loaded_config: The loaded config dictionary (guaranteed non-None)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
)
@classmethod
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return steps, remaining_override_keys
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
processor_steps.append(processor_step)
# 3. Load step state if available
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return processor_steps, remaining_override_keys
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
steps.append(step_instance)
return steps, override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: Any) -> bool:
def _is_processor_config(cls, config: dict) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions
@ProcessorStepRegistry.register("relative_actions_processor")
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
-2
View File
@@ -20,14 +20,12 @@ from .factory import (
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
"TOPRewardConfig",
# Base class
+2 -16
View File
@@ -25,7 +25,6 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig
@@ -39,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm", "robometer", "topreward".
"sarm", "topreward".
Returns:
The reward model class corresponding to the given name.
@@ -55,10 +54,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "robometer":
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
return RobometerRewardModel
elif name == "topreward":
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
@@ -79,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm", "robometer", "topreward".
"reward_classifier", "sarm", "topreward".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -92,8 +87,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
elif reward_type == "robometer":
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
else:
@@ -175,13 +168,6 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(reward_cfg, RobometerConfig):
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
return make_robometer_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, TOPRewardConfig):
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
-19
View File
@@ -1,19 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_robometer import RobometerConfig
from .modeling_robometer import RobometerRewardModel
from .processor_robometer import make_robometer_pre_post_processors
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
@@ -1,320 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
For each episode, builds per-frame sub-samples using the frame-steps
strategy from the Robometer eval server: for each original frame ``t``,
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
the Robometer processor + model, and keep the last-frame progress value.
All sub-samples are the same size ``K`` so they batch cleanly.
The parquet uses the same schema as SARM's
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
``progress_sparse``) and the progress-overlay script in
``examples/dataset/create_progress_videos.py`` work without modification.
Usage:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes with batching
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
from lerobot.types import TransitionKey
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
"""Read ``reward_model_path`` from parquet metadata if available."""
if not parquet_path.exists():
return None
try:
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
if metadata and b"reward_model_path" in metadata:
return metadata[b"reward_model_path"].decode()
except Exception: # nosec B110
return None
return None
def _resolve_task(sample: dict[str, Any], default: str) -> str:
"""Best-effort task extraction from a dataset sample."""
task = sample.get("task")
if isinstance(task, str) and task:
return task
return default
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
"""Frame-steps linspace expansion.
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
indices from ``np.linspace(0, t, num_subsampled_frames)`` the first
and last frames are always included. Each entry is a fixed-size array
so the model can batch them.
"""
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
def compute_robometer_progress(
dataset_repo_id: str,
reward_model_path: str,
output_path: str | None = None,
device: str = "cuda",
batch_size: int = 32,
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
episodes: list[int] | None = None,
image_key: str | None = None,
) -> Path:
"""Run Robometer over a dataset and write per-frame progress + success."""
logging.info(f"Loading Robometer: {reward_model_path}")
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
if image_key is not None:
config.image_key = image_key
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
model.to(device).eval()
encoder = RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=num_subsampled_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
)
image_key = config.image_key
logging.info(f"Loading dataset: {dataset_repo_id}")
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
logging.info(f"Processing {len(episode_indices)} episode(s)")
all_index: list[int] = []
all_episode: list[int] = []
all_frame: list[int] = []
all_progress: list[float] = []
for episode_idx in tqdm(episode_indices, desc="Episodes"):
ep = dataset.meta.episodes[episode_idx]
ep_start = int(ep["dataset_from_index"])
ep_end = int(ep["dataset_to_index"])
num_frames = ep_end - ep_start
if num_frames <= 0:
continue
first_sample = dataset[ep_start]
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
end = min(start + batch_size, num_frames)
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
transition = {
TransitionKey.OBSERVATION: {image_key: frames_batch},
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
}
encoded = encoder(transition)
obs = encoded[TransitionKey.OBSERVATION]
batch = {
key: value.to(device) if isinstance(value, torch.Tensor) else value
for key, value in obs.items()
}
with torch.no_grad():
rewards = model.compute_reward(batch)
progress_per_frame[start:end] = rewards.cpu().numpy()
for local in range(num_frames):
all_index.append(ep_start + local)
all_episode.append(episode_idx)
all_frame.append(local)
all_progress.append(float(progress_per_frame[local]))
if device.startswith("cuda"):
torch.cuda.empty_cache()
table = pa.table(
{
"index": np.asarray(all_index, dtype=np.int64),
"episode_index": np.asarray(all_episode, dtype=np.int64),
"frame_index": np.asarray(all_frame, dtype=np.int64),
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
}
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out)
logging.info(f"Saved {len(table)} frame values to {out}")
progress_arr = np.asarray(all_progress, dtype=np.float32)
if progress_arr.size:
logging.info(
f"Progress: mean={float(progress_arr.mean()):.4f}, "
f"std={float(progress_arr.std()):.4f}, "
f"min={float(progress_arr.min()):.4f}, "
f"max={float(progress_arr.max()):.4f}"
)
return out
def main():
parser = argparse.ArgumentParser(
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes, smaller batches for memory-constrained GPUs
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
""",
)
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
)
parser.add_argument(
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
)
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
parser.add_argument(
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
)
parser.add_argument(
"--num-subsampled-frames",
type=int,
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
)
parser.add_argument(
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
)
parser.add_argument(
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
)
parser.add_argument(
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
reward_model_path = args.reward_model_path
if reward_model_path is None:
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
if reward_model_path:
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
else:
raise ValueError(
"--reward-model-path is required (no existing parquet with model metadata found)."
)
output_path = compute_robometer_progress(
dataset_repo_id=args.dataset_repo_id,
reward_model_path=reward_model_path,
output_path=args.output_path,
device=args.device,
batch_size=args.batch_size,
num_subsampled_frames=args.num_subsampled_frames,
episodes=args.episodes,
image_key=args.image_key,
)
print(f"\nRobometer progress saved to: {output_path}")
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = DEFAULT_OUTPUT_FILENAME
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(output_path),
path_in_repo=hub_path,
repo_id=args.dataset_repo_id,
repo_type="dataset",
)
print(
"Successfully uploaded to: "
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
)
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
print(" rabc_head_mode: sparse")
else:
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: {output_path}")
print(" rabc_head_mode: sparse")
if __name__ == "__main__":
main()
@@ -1,158 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoTokenizer
else:
AutoConfig = None # type: ignore[assignment]
AutoTokenizer = None # type: ignore[assignment]
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
# The order is part of the data contract: upstream resized ``embed_tokens``
# after adding these tokens in this exact order, so changing the set or order
# would silently misalign the saved embedding rows with their token ids.
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
# heads (never read at inference) but still occupy rows the checkpoint expects.
ROBOMETER_SPECIAL_TOKENS = (
"<|split_token|>",
"<|reward_token|>",
"<|pref_token|>",
"<|sim_token|>",
"<|prog_token|>",
)
@RewardModelConfig.register_subclass("robometer")
@dataclass
class RobometerConfig(RewardModelConfig):
"""Configuration for the Robometer reward model."""
pretrained_path: str | None = "lerobot/Robometer-4B"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
reward_output: str = "progress" # "progress" or "success"
success_threshold: float = 0.5
license: str | None = "apache-2.0"
tags: list[str] | None = field(
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
)
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
torch_dtype: str = "bfloat16"
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
average_temporal_patches: bool = True
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
frame_pooling_attn_temperature: float = 1.0
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
progress_discrete_bins: int = 10
# Serialised Qwen backbone config (post-resize). Always populated by
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
# is non-empty after construction. Saved into ``config.json`` automatically
# by the base ``_save_pretrained``.
vlm_config: dict[str, Any] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self) -> None:
super().__post_init__()
if self.reward_output not in {"progress", "success"}:
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
if self.max_frames is not None and self.max_frames < 1:
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
if self.frame_pooling not in {"mean", "boundary", "attention"}:
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
if self.frame_pooling_attn_temperature <= 0:
raise ValueError("frame_pooling_attn_temperature must be > 0")
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
if self.use_per_frame_progress_token and not self.use_multi_image:
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
if self.image_key not in self.input_features:
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
# Deterministically populate ``vlm_config`` so it is non-empty after
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
# vocab the published ``Robometer-4B`` checkpoint was saved with.
if not self.vlm_config:
require_package("transformers", extra="robometer")
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
text_config = vlm.get("text_config")
if not isinstance(text_config, dict):
raise ValueError(
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
"Robometer expects a Qwen-VL-style config."
)
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
self.vlm_config = vlm
@property
def use_discrete_progress(self) -> bool:
"""Whether the progress head outputs distribution logits over bins."""
return self.progress_loss_type.lower() == "discrete"
@property
def vlm_backbone_config(self):
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
require_package("transformers", extra="robometer")
config_dict = deepcopy(self.vlm_config)
model_type = config_dict.pop("model_type", None)
if model_type is None:
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
return AutoConfig.for_model(model_type, **config_dict)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> None:
return None
@property
def reward_delta_indices(self) -> None:
return None
def validate_features(self) -> None:
if self.image_key not in self.input_features:
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
@@ -1,481 +0,0 @@
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
Paper: https://arxiv.org/abs/2603.02115
Project: https://robometer.github.io
Original code: https://github.com/aliang8/robometer
Model: https://huggingface.co/robometer/Robometer-4B
Robometer is a general-purpose, video-language-input reward model built on
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
objective:
- A frame-level progress loss anchoring reward magnitude on expert data.
- A trajectory-comparison preference loss imposing global ordering constraints
across trajectories sharing the same instruction.
To support downstream RL it also predicts a frame-level binary success. The
training prompt inserts three learnable tokens:
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
- ``<|split_token|>`` between two trajectories in preference samples
(training-only).
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
is recovered as the softmax-weighted mean of those centers see
:func:`convert_bins_to_continuous`.
This LeRobot port is **inference-only**: the preference head is preserved in
the state dict for byte-equivalence with the published ``Robometer-4B``
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
success probability depending on :attr:`RobometerConfig.reward_output`.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import torch
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModelForImageTextToText
else:
AutoModelForImageTextToText = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
ROBOMETER_QWEN_INPUT_KEYS = (
"input_ids",
"attention_mask",
"pixel_values",
"pixel_values_videos",
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"mm_token_type_ids",
)
ROBOMETER_METADATA_KEYS = (
"prog_token_id",
"vision_start_token_id",
"vision_end_token_id",
"video_merge_size",
)
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
"""Collapse per-bin logits into a single value in ``[0, 1]``.
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
softmax-weighted mean of those centers.
"""
bin_probs = torch.softmax(bin_logits, dim=-1)
num_bins = bin_logits.shape[-1]
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
return (bin_probs * bin_centers).sum(dim=-1)
def _squeeze_last_safe(x: Tensor) -> Tensor:
"""Drop a trailing singleton dim only when present."""
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
def _torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if isinstance(dtype, torch.dtype):
return dtype
raise ValueError(f"Unknown torch dtype: {name!r}")
class RobometerPredictionHead(nn.Sequential):
"""Small MLP head used for Robometer's progress / success / preference outputs."""
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
layers: list[nn.Module] = [
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_size),
]
if with_sigmoid:
layers.append(nn.Sigmoid())
super().__init__(*layers)
def decode_progress_outputs(
progress_logits: Tensor | None,
success_logits: Tensor | None,
*,
is_discrete_mode: bool,
) -> dict[str, list[list[float]]]:
"""Decode RBM head outputs into per-frame floats.
Args:
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
is_discrete_mode: if True the progress logits get a softmax over bins
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
Returns:
Dict with ``progress_pred`` and ``success_probs``, each a list of
length ``B`` of per-frame float lists.
"""
progress_pred: list[list[float]] = []
success_probs: list[list[float]] = []
if progress_logits is not None:
for sample_logits in progress_logits:
if is_discrete_mode:
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
progress_pred.append(continuous.flatten().tolist())
else:
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
if success_logits is not None:
for sample_logits in success_logits:
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
return {"progress_pred": progress_pred, "success_probs": success_probs}
class RobometerRewardModel(PreTrainedRewardModel):
"""Robometer (RBM) reward model — inference-only LeRobot port.
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
prediction heads from the paper (progress, success, preference). At
inference time only the progress and success heads are queried; the
preference head is kept on the module so the published ``Robometer-4B``
safetensors load unchanged.
"""
name = "robometer"
config_class = RobometerConfig
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
require_package("transformers", extra="robometer")
super().__init__(config)
self.config = config
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
#
# - Fresh training (``pretrained_path is None``): download the base
# Qwen weights and resize the embed table to match
# ``vlm_config.text_config.vocab_size`` — populated deterministically
# in ``RobometerConfig.__post_init__`` as
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
#
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
# the empty architecture from ``vlm_config`` via
# ``AutoModelForImageTextToText.from_config`` so the subsequent
# ``model.safetensors`` load is a direct fill of the right shape —
# no redundant Qwen weight download.
torch_dtype = _torch_dtype(config.torch_dtype)
if config.pretrained_path is None:
self.model = AutoModelForImageTextToText.from_pretrained(
config.base_model_id,
dtype=torch_dtype,
trust_remote_code=True,
)
target_vocab = config.vlm_config["text_config"]["vocab_size"]
self.model.resize_token_embeddings(target_vocab)
else:
self.model = AutoModelForImageTextToText.from_config(
config.vlm_backbone_config,
dtype=torch_dtype,
trust_remote_code=True,
)
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
# Falls back to the top-level `hidden_size` so future non-multimodal
# variants would still resolve.
backbone_config = self.model.config
text_config = getattr(backbone_config, "text_config", None)
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
if hidden_size is None:
hidden_size = getattr(backbone_config, "hidden_size", None)
if hidden_size is None:
raise AttributeError(
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
)
hidden_dim = int(hidden_size)
# Robometer's three prediction heads + frame-pool attention.
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
self.progress_head = RobometerPredictionHead(
hidden_dim,
progress_output,
dropout=dropout,
with_sigmoid=not config.use_discrete_progress,
)
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
# Match the dtype of the loaded base model so weight loading is a no-op cast.
model_dtype = next(self.model.parameters()).dtype
self.progress_head.to(dtype=model_dtype)
self.preference_head.to(dtype=model_dtype)
self.success_head.to(dtype=model_dtype)
self.frame_pool_attn.to(dtype=model_dtype)
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
inputs = {
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
for key in ROBOMETER_INPUT_KEYS
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
}
if "input_ids" not in inputs:
raise KeyError(
f"Robometer batch missing pre-encoded inputs (expected "
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
"RobometerEncoderProcessorStep ran before `compute_reward`."
)
device = next(self.model.parameters()).device
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
self.eval()
with torch.no_grad():
progress_logits, success_logits = self._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=self.config.use_discrete_progress,
)
values = (
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
)
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
if self.config.reward_output == "success":
rewards = (rewards > self.config.success_threshold).float()
else:
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
# progress predictions are clamped to ``[0, 1]`` before being returned.
rewards = rewards.clamp(0.0, 1.0)
return rewards.to(self.config.device or "cpu")
def _compute_rbm_logits(
self,
inputs: dict[str, Any],
) -> tuple[Tensor, Tensor]:
"""Run the Qwen3-VL backbone and apply Robometer's heads.
``inputs`` is the encoded batch produced by
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
as Robometer-specific metadata (``prog_token_id``,
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
the metadata is popped here so the rest can be forwarded straight to
the Qwen model.
Returns ``(progress_logits, success_logits)``. Shapes:
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
"""
prog_token_id = inputs.pop("prog_token_id", None)
vision_start_token_id = inputs.pop("vision_start_token_id", None)
vision_end_token_id = inputs.pop("vision_end_token_id", None)
video_merge_size = inputs.pop("video_merge_size", 14)
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
# full hidden-state tuple and take the last layer. This matches the
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
hidden_state = (
outputs.hidden_states[-1]
if getattr(outputs, "hidden_states", None)
else outputs.last_hidden_state
)
input_ids = inputs["input_ids"]
if self.config.use_per_frame_progress_token:
if prog_token_id is None:
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
if self.config.use_multi_image:
if vision_start_token_id is None or vision_end_token_id is None:
raise KeyError(
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
"(run RobometerEncoderProcessorStep first)"
)
return self._process_multi_image_frames(
hidden_state,
input_ids,
start_id=vision_start_token_id,
end_id=vision_end_token_id,
)
video_grid_thw = inputs.get("video_grid_thw")
if video_grid_thw is None:
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
if vision_start_token_id is None:
raise KeyError("`vision_start_token_id` missing in batch")
return self._process_video_frames(
hidden_state,
input_ids,
video_grid_thw,
start_id=vision_start_token_id,
merge_size=video_merge_size,
)
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
"""Apply progress + success heads to a tensor of frame embeddings."""
progress_out = self.progress_head(frame_embeddings)
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
success = _squeeze_last_safe(self.success_head(frame_embeddings))
return progress, success
def _process_token_extraction(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
prog_token_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
token_mask = input_ids == prog_token_id
batch_indices, positions = token_mask.nonzero(as_tuple=True)
if positions.numel() == 0:
raise ValueError("`<|prog_token|>` not found in any sequence")
per_sample_hidden = [
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
]
progress_list, success_list = [], []
for embeddings in per_sample_hidden:
if embeddings.shape[0] == 0:
raise ValueError("`<|prog_token|>` missing in a sequence")
progress, success = self._apply_heads_to_hidden_states(embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _process_multi_image_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
start_id: int,
end_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
frame_embeddings = self._extract_hidden_states_from_token_pairs(
seq_hidden, seq_ids, start_id, end_id
)
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _extract_hidden_states_from_token_pairs(
self,
hidden_state: Tensor,
input_ids: Tensor,
start_id: int,
end_id: int,
) -> Tensor:
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
if start_positions.numel() != end_positions.numel():
raise ValueError(
f"Mismatched vision token counts: {start_positions.numel()} start vs "
f"{end_positions.numel()} end"
)
frames: list[Tensor] = []
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
if start >= end:
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
patch_tokens = hidden_state[start + 1 : end]
if patch_tokens.shape[0] == 0:
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
continue
pooling = self.config.frame_pooling
if pooling == "mean":
frames.append(patch_tokens.mean(dim=0))
elif pooling == "boundary":
frames.append(patch_tokens[-1])
else: # attention
scores = (
self.frame_pool_attn(patch_tokens).squeeze(-1)
/ self.config.frame_pooling_attn_temperature
)
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
frames.append((weights * patch_tokens).sum(dim=0))
return torch.stack(frames)
def _process_video_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
video_grid_thw: Tensor,
*,
start_id: int,
merge_size: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in video mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
cursor = start_positions[0].item()
frame_embeddings: list[Tensor] = []
for _ in range(t_dim):
if self.config.average_temporal_patches:
patch = seq_hidden[cursor : cursor + tokens_per_frame]
frame_embeddings.append(patch.mean(dim=0))
else:
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
cursor += tokens_per_frame
stacked = torch.stack(frame_embeddings)
progress, success = self._apply_heads_to_hidden_states(stacked)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
@@ -1,338 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Robometer pre/post processing pipelines."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
policy_action_to_transition,
)
from lerobot.rewards.robometer.configuration_robometer import (
ROBOMETER_SPECIAL_TOKENS,
RobometerConfig,
)
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor
else:
AutoProcessor = None
PROGRESS_PROMPT = (
"The task for the robot is '{task}'. Given the trajectory video, predict "
"the task progress at each frame, how far along the robot is towards "
"completing the task, a float between 0 and 1, where 0 is the starting "
"state and 1 is when the task is completed. If the robot is not "
"performing the same task, predict 0 progress."
)
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
if frames.ndim != 4:
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
if frames.dtype != np.uint8:
frames = np.clip(frames, 0, 255).astype(np.uint8)
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
if max_frames is not None:
video = video[-max_frames:]
if video.shape[1] in (1, 3):
video = video.permute(0, 2, 3, 1)
elif video.shape[-1] not in (1, 3):
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
array = video.detach().cpu().numpy()
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
array = array * 255.0
return np.clip(array, 0, 255).astype(np.uint8)
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
if task is None:
task = default
if task is None:
raise KeyError("Robometer expected a task description in complementary data")
if isinstance(task, str):
return [task] * batch_size
if isinstance(task, tuple):
task = list(task)
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
if len(task) == 1 and batch_size > 1:
return task * batch_size
if len(task) != batch_size:
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
return task
@dataclass
@ProcessorStepRegistry.register(name="robometer_encoder")
class RobometerEncoderProcessorStep(ProcessorStep):
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
registers Robometer's special tokens on the tokenizer. The matching
embedding resize happens model-side in
:meth:`RobometerRewardModel.__init__`.
At call time the step reads:
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
- ``complementary_data[task_key]``: a string or list of strings.
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
- Robometer-specific token ids consumed by the model heads:
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
``video_merge_size``.
"""
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
max_length: int = 1024
_processor: Any = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
require_package("transformers", extra="robometer")
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
self._processor = AutoProcessor.from_pretrained(
self.base_model_id,
trust_remote_code=True,
do_sample_frames=False,
padding_side="right",
)
# Register Robometer's special tokens on the tokenizer. The matching
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
tokenizer = self._processor.tokenizer
# Qwen tokenizers may not define a pad token, but batched prompts/videos
# require padding, so reuse EOS as the padding token.
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
for token in ROBOMETER_SPECIAL_TOKENS:
if token not in tokenizer.get_vocab():
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if not isinstance(observation, dict):
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
if self.image_key not in observation:
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
frames = observation[self.image_key]
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
if tensor.ndim == 4:
tensor = tensor.unsqueeze(1)
elif tensor.ndim != 5:
raise ValueError(
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
)
batch_size = tensor.shape[0]
tasks = _expand_tasks(
complementary.get(self.task_key, self.default_task),
batch_size=batch_size,
default=self.default_task,
)
samples = [
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
]
encoded = self.encode_samples(samples)
new_observation = dict(observation)
for key, value in encoded.items():
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
from qwen_vl_utils import process_vision_info
conversations = [self._build_conversation(frames, task) for frames, task in samples]
texts = [
self._processor.apply_chat_template(
msg,
tokenize=False,
add_generation_prompt=False,
add_vision_id=True,
enable_thinking=False,
fps=1,
)
for msg in conversations
]
process_kwargs: dict[str, Any] = {
"return_video_kwargs": True,
"return_video_metadata": True,
}
image_processor = getattr(self._processor, "image_processor", None)
if image_processor is not None and hasattr(image_processor, "patch_size"):
process_kwargs["image_patch_size"] = image_processor.patch_size
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
videos: list[Any] | None = None
video_metadatas: list[Any] | None = None
if video_inputs:
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
videos = list(videos_seq)
video_metadatas = list(metadatas_seq)
else:
videos = list(video_inputs)
processor_kwargs: dict[str, Any] = {
"text": texts,
"images": image_inputs,
"padding": True,
"truncation": False,
"max_length": self.max_length,
"return_tensors": "pt",
"do_resize": False,
}
if videos is not None:
processor_kwargs["videos"] = videos
if video_metadatas is not None:
processor_kwargs["video_metadata"] = video_metadatas
if video_kwargs:
processor_kwargs.update(video_kwargs)
encoded = self._processor(**processor_kwargs)
# Write Robometer-specific token ids and the video patch merge size into
# the encoded batch so `RobometerRewardModel` doesn't need its own
# tokenizer at inference (EO1-style separation: the processor owns the
# tokenizer, the model owns the backbone and heads).
tokenizer = self._processor.tokenizer
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
video_processor = getattr(self._processor, "video_processor", None)
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
return encoded
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
pil_frames = _frames_to_pil(frames)
prompt = PROGRESS_PROMPT.format(task=task)
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
if self.use_multi_image:
for image in pil_frames:
content.append({"type": "image", "image": image})
if self.use_per_frame_progress_token:
content.append({"type": "text", "text": "<|prog_token|>"})
else:
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
return [{"role": "user", "content": content}]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"base_model_id": self.base_model_id,
"image_key": self.image_key,
"task_key": self.task_key,
"default_task": self.default_task,
"max_frames": self.max_frames,
"use_multi_image": self.use_multi_image,
"use_per_frame_progress_token": self.use_per_frame_progress_token,
"max_length": self.max_length,
}
def make_robometer_pre_post_processors(
config: RobometerConfig,
dataset_stats: dict[str, dict[str, Any]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
The preprocessor adds a batch dimension if needed, runs Robometer's
encoder, and moves everything to the configured device. The
postprocessor is the identity since Robometer outputs a single reward
tensor.
"""
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=config.max_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(BimanualMixin, Robot):
class BiOpenArmFollower(Robot):
"""
Bimanual OpenArm Follower Arms
"""
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(BimanualMixin, Robot):
class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
for k, v in self.left_arm.get_observation().items():
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
return obs_dict
@check_if_not_connected
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
class BiSOFollower(BimanualMixin, Robot):
class BiSOFollower(Robot):
"""
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
)
right_arm_config = SOFollowerRobotConfig(
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..so_follower import SOFollowerConfig
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+3 -12
View File
@@ -68,12 +68,9 @@ class SOFollower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
features: dict[str, tuple] = {}
for cam in self.cameras:
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
if getattr(self.cameras[cam], "use_depth", False):
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
return features
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -193,12 +190,6 @@ class SOFollower(Robot):
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
if getattr(cam, "use_depth", False):
start = time.perf_counter()
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
return obs_dict
@check_if_not_connected

Some files were not shown because too many files have changed in this diff Show More