Compare commits

..

18 Commits

Author SHA1 Message Date
hq-fang c78023dae7 sync uv.lock with main 2026-05-27 15:49:59 +00:00
hq-fang 36d0ba5127 validate molmoact2 gripper range 2026-05-22 22:14:51 +00:00
hq-fang dca792951e fix molmoact2 pre-commit checks 2026-05-22 22:14:50 +00:00
hq-fang 0a369e104a skip molmoact2 tests without optional deps 2026-05-22 22:14:50 +00:00
hq-fang b0cdf99957 format molmoact2 files 2026-05-22 22:14:50 +00:00
hq-fang 733f9768b5 lazy import molmoact2 scipy 2026-05-22 22:14:50 +00:00
hq-fang 7fe49f9e54 load molmoact2 without remote code 2026-05-22 22:14:50 +00:00
hq-fang e1afb96474 fix molmoact2 hf image key resolution 2026-05-22 22:14:50 +00:00
hq-fang f395f36dec move molmoact2 config logic into config 2026-05-22 22:14:50 +00:00
hq-fang 738ba9272f use a single molmoact2 action queue 2026-05-22 22:14:50 +00:00
hq-fang 2a0495f8c3 add scipy dependency to molmoact2 extra 2026-05-22 22:14:50 +00:00
hq-fang c3c9c2b089 guard molmoact2 processor transformers import 2026-05-22 22:14:50 +00:00
hq-fang e13c6a6110 guard molmoact2 transformers imports 2026-05-22 22:14:50 +00:00
hq-fang 140cf2a420 remove molmoact2 processor override from factory 2026-05-22 22:14:50 +00:00
hq-fang c092194cf2 align molmoact2 feature validation with eo pattern 2026-05-22 22:14:50 +00:00
hq-fang b858ba1b6c simplify molmoact2 package imports 2026-05-22 22:14:50 +00:00
hq-fang e870af119f add apache headers to molmoact2 files 2026-05-22 22:14:50 +00:00
hq-fang 4174c3b303 add molmoact2 policy 2026-05-22 22:14:49 +00:00
143 changed files with 165 additions and 23676 deletions
-6
View File
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
--env.episode_length=5 \
--eval.n_episodes=1 \
--eval.batch_size=1
# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
uv run python -m tests.annotations.run_e2e_smoke
-8
View File
@@ -9,8 +9,6 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: lelab
title: LeLab - Lerobot GUI
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
@@ -45,8 +43,6 @@
title: Language Columns and Recipes
- local: tools
title: Tools
- local: annotation_pipeline
title: Annotation Pipeline
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
@@ -65,8 +61,6 @@
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: groot
@@ -81,8 +75,6 @@
- sections:
- local: sarm
title: SARM
- local: robometer
title: ROBOMETER
- local: topreward
title: TOPReward
title: "Reward Models"
-281
View File
@@ -1,281 +0,0 @@
# Annotation Pipeline
`lerobot-annotate` watches each episode's video with a vision-language
model (VLM) and writes natural-language annotations back into your
dataset. It fills the two language columns from the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — straight into
`data/chunk-*/file-*.parquet`.
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
memory, interjections, speech, and visual Q&A that a policy can be
trained on.
## How it fits together
```text
your dataset lerobot-annotate
(LeRobot v3.1)
┌─────────────────────────────────────────────────────┐
│ read episodes │
└──────────────────────────┬──────────────────────────┘
┌────────────────────┼────────────────────┐
▼ ▼ ▼
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
└────────────────────┼─────────────────────┘
│ each module stages raw JSONL
▼ into .annotate_staging/
┌─────────────────┐
│ validator │ ◀── checks everything
└────────┬────────┘
┌─────────────────┐
│ writer │
└────────┬────────┘
data/chunk-*/file-*.parquet
(+ meta/info.json tools)
```
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
VLM. Each module stages its output to disk, a validator checks it, and a
single writer rewrites the dataset shards in place.
## What the pipeline produces
Each module emits a few kinds of annotation ("styles"), routed to one of
the two language columns:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | --------------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
| `interjection` | `language_events` | `interjections` |
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
### How subtasks are generated
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
it uses a two-step **describe → segment** flow:
1. **Describe** — the VLM narrates only what it actually sees in the
chosen camera (no guessing about the task).
2. **Segment** — that description is fed back in, and the VLM splits the
episode into consecutive atomic subtasks.
The resulting spans are then stitched into a gap-free, full-episode
cover, so **every frame has exactly one active subtask**. See
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
for the production settings (single camera, embedded frames, windowed
subtask generation).
### Tools
The writer does **not** add a `tools` column to the parquet. The tool
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
After every run, the pipeline makes sure the canonical `say` schema is in
that list, keeping any tools you declared beforehand.
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
pipeline preserves whatever is already there. That makes the tool visible
to the chat template, so the model can learn to _generate_ the call. The
runtime layer that actually _executes_ a generated call (the `Tool`
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
this PR — the [Tools](./tools) doc marks those pieces as
not-yet-implemented.
## Running on Hugging Face Jobs
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
The repo ships a launcher script you copy and tweak for your dataset:
```bash
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
```
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
that:
1. installs `lerobot` (from `main`) plus the annotation extras,
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
drives it over the OpenAI-compatible API,
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
with `lerobot-annotate`,
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
back to `--repo_id` in place if you leave that unset).
To use a different dataset, model, or hub repo, edit the `CMD` block in
the script. Every flag there maps directly to a `lerobot-annotate` flag
(run `lerobot-annotate --help` for the full list).
## Key options
These are the flags you'll reach for most often. Run
`lerobot-annotate --help` for everything else; the defaults are tuned for
short manipulation episodes.
### Dataset in / out
| Flag | Default | What it does |
| ----------------- | ------- | ----------------------------------------------------------------------- |
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
| `--root` | — | Annotate a local dataset directory instead. |
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
### Which modules run
Every module is on by default and can be toggled independently (set to
`false` to skip it, e.g. to iterate on one module at a time):
| Flag | Default | Turns off |
| ------------------------- | ------- | ----------------------------------- |
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
| `--interjections.enabled` | `true` | interjections + speech atoms |
| `--vqa.enabled` | `true` | the VQA pairs |
### The VLM (`--vlm.*`)
| Flag | Default | What it does |
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
| `--vlm.temperature` | `0.2` | Sampling temperature. |
### Subtasks / plan / memory (`--plan.*`)
| Flag | Default | What it does |
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
| `--plan.frames_per_second` | `1.0` | How densely the episode video is sampled. |
| `--plan.max_video_frames` | `32` | Hard cap on frames per call (context-budget guard — don't exceed ~32 for a 32k context). |
| `--plan.subtask_window_seconds` | `0` | Split long episodes into fixed windows for constant frame density (`0` = whole episode). |
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
| `--plan.use_video_url` | `false` | Send a server-side video clip instead of embedded frames. |
### Interjections + VQA
| Flag | Default | What it does |
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
## Contributing new modules
The pipeline is built to grow, and **contributions are very welcome** —
a brand-new module (say, trajectory traces or affordances), a new prompt
template, a smarter grounding flow, or quality fixes to the existing
`plan` / `interjections` / `vqa` modules.
Every module lives under
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
client and the keyframe cache, writes its raw output to the staging
tree, and plugs into the executor as its own phase. Got an idea? Open an
issue or PR on [the repo](https://github.com/huggingface/lerobot).
## How recipes consume the output
The annotations are meant to be read by recipes (see
[Language Columns and Recipes](./language_and_recipes)). Typically:
- low-level / high-level / memory-update branches read
`subtask` / `plan` / `memory` from `language_persistent`.
- an interjection-response branch reads `interjection` events plus the
paired speech atom (merged into one assistant turn via `tool_calls_from`)
and the matching `plan` refresh at the same timestamp.
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
`language_events`.
## Why state and events are split
Two ideas shape the design:
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
`plan`, `memory`) apply to the whole episode and answer "what's true
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
the one frame whose timestamp matches. Timestamps are copied straight
from the source parquet — never recomputed in floating point.
2. **One VLM pass.** All three modules share a single VLM client (the
OpenAI-compatible client talking to the job's vLLM server), so you pay
for one model load per dataset, not three.
## Re-running a single module
Each module stages its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
prompt iteration cheap: re-running one module overwrites only its own
JSONL, then the writer recomposes the final parquet. Disable modules you
don't want with `--plan.enabled=false` (and likewise
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
## What the validator checks
Before the writer runs, `StagingValidator` confirms:
- every event row lands exactly on a real frame timestamp;
- no speech / interjection pairs are left orphaned;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (a warning, not an error);
- each VQA assistant `content` is valid JSON in one of the
bbox / keypoint / count / attribute / spatial shapes;
- every row goes to the column chosen by `column_for_style(style)`.
Any error aborts the writer. Pass `--skip_validation=true` to override
while debugging.
## Where each module's ideas come from
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
for atom granularity ("pick up one piece of lettuce", "place bowl to
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
for "how, not what" detail.
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
keep only the minimal relevant information — preserve outcomes, drop
specific attributes.
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable VLA Policies
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
grounding. Pi0.7 also grounds answers across abstraction levels.
When improving a module, tweak its prompt template in
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
rewriting from scratch.
## Roughly how much it costs
Per episode, the pipeline makes about `max_steps` plan calls,
`max_interjections_per_episode` interjection calls, and
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
~50 VLM calls.
Storage stays small: `language_persistent` is at most tens of KB per
episode (parquet dictionary-encodes the one entry that repeats across
frames), and `language_events` is empty on most frames — its size scales
with the number of emissions, not `num_frames × num_emissions`.
-1
View File
@@ -647,6 +647,5 @@ The `--strategy.type` flag selects the execution mode:
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
- `episodic`: Episode-oriented policy recording with reset phases between episodes
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
-38
View File
@@ -157,44 +157,6 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
### Episodic (`--strategy.type=episodic`)
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
```bash
lerobot-rollout \
--strategy.type=episodic \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so100_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/my_eval_data \
--dataset.num_episodes=20 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--dataset.single_task="Pick up the red cube"
```
Teleop is optional — if omitted the robot holds its position during the reset phase.
**Keyboard controls:**
| Key | Action |
| ----------- | -------------------------------- |
| `→` (right) | End the current episode early |
| `←` (left) | Discard episode and re-record it |
| `ESC` | Stop the recording session |
| Flag | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------- |
| `--dataset.num_episodes` | Number of episodes to record |
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
---
## Inference Backends
-5
View File
@@ -141,11 +141,6 @@ sample["target_message_indices"]
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
-29
View File
@@ -1,29 +0,0 @@
# LeLab - LeRobot Guide
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
> [!TIP]
> For now LeLab is compatible only with SO-ARM101
<Youtube id="VqyKUuW9V1g" />
### Installation
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
```
uv tool install git+https://github.com/huggingface/leLab.git && lelab
```
After install, run `lelab` from your terminal anytime to start the app.
### Features
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
+1 -1
View File
@@ -275,7 +275,7 @@ A converter aggregates perepisode files into larger shards and writes episode
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
# Convert an existing v2.1 dataset hosted on the Hub:
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
```
**What it does**
+1 -1
View File
@@ -238,7 +238,7 @@ your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
+1 -1
View File
@@ -91,7 +91,7 @@ lerobot-train \
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset \
```
-39
View File
@@ -1,39 +0,0 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+1 -1
View File
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
If you have existing datasets in v2.1 format, use the migration tool:
```bash
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id your_id/existing_dataset
```
-185
View File
@@ -1,185 +0,0 @@
# ROBOMETER
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
**Project**: [robometer.github.io](https://robometer.github.io/)
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
## Overview
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
- **Success head**: predicts per-frame task success probability.
- **Preference head**: predicts which of two trajectories better completes the task during training.
The paper trains ROBOMETER with a composite objective:
```text
L = L_pref + L_prog + L_succ
```
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
## What the LeRobot Integration Covers
- Standard `reward_model.type=robometer` configuration through LeRobot.
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
- Dense, frame-level progress and success predictions internally.
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the ROBOMETER dependencies:
```bash
pip install -e ".[robometer]"
```
If you use `uv` directly from a source checkout:
```bash
uv sync --extra robometer
```
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
## Model Inputs and Outputs
ROBOMETER expects:
- A trajectory video or sequence of frames.
- A natural-language task description.
In LeRobot datasets, the preprocessor reads:
| Config field | Default | Meaning |
| ------------------------- | ------------------------ | ----------------------------------------------------- |
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
## Usage
### Load the Reward Model Directly
```python
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
cfg = RobometerConfig(
pretrained_path="lerobot/Robometer-4B",
device="cuda",
reward_output="progress",
)
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
```
### Encode Frames and Compute a Reward
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
```python
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
# task: str
encoder = RobometerEncoderProcessorStep(
base_model_id=cfg.base_model_id,
use_multi_image=cfg.use_multi_image,
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
max_frames=cfg.max_frames,
)
encoded = encoder.encode_samples([(frames, task)])
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
reward = reward_model.compute_reward(batch)
```
`reward` is a tensor of shape `(batch_size,)`.
### Use the Reward Factory
You can also instantiate ROBOMETER through the reward factory:
```python
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
cfg = make_reward_model_config(
"robometer",
pretrained_path="lerobot/Robometer-4B",
device="cuda",
image_key="observation.images.top",
)
reward_model = make_reward_model(cfg)
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
## Configuration Notes
### Backbone and Vocabulary
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
```text
<|split_token|>
<|reward_token|>
<|pref_token|>
<|sim_token|>
<|prog_token|>
```
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
### Progress Prediction
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
### Success Prediction
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
## Limitations
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
## References
- [ROBOMETER project](https://robometer.github.io/)
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
- [Original ROBOMETER code](https://github.com/robometer/robometer)
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
## Citation
```bibtex
@inproceedings{liang2026robometer,
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
year={2026},
booktitle={Robotics: Science and Systems 2026},
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
-235
View File
@@ -1,235 +0,0 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
-109
View File
@@ -1,109 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
Spawns one single-GPU ``h200`` job that:
1. installs ``lerobot`` from ``main`` plus the annotation extras,
2. boots one vllm server with Qwen3.6-27B (dense VLM),
3. runs the plan / interjections / vqa modules across the dataset
in free-form mode (each episode generates its own subtasks +
memory),
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
or back to ``--repo_id``.
Usage:
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
run. For larger datasets, scale to ``h200x4`` and raise
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
"""
import os
from huggingface_hub import get_token, run_job
token = os.environ.get("HF_TOKEN") or get_token()
if not token:
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
CMD = (
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
"pip install --no-deps "
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
"pip install --upgrade-strategy only-if-needed "
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
"openai && "
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated5 "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-27B "
"--vlm.parallel_servers=1 "
"--vlm.num_gpus=1 "
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
"--tensor-parallel-size 1 --max-model-len 32768 "
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
"--vlm.serve_ready_timeout_s=1800 "
"--vlm.client_concurrency=128 "
"--vlm.max_new_tokens=512 "
"--vlm.temperature=0.7 "
"--executor.episode_parallelism=16 "
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
"--vlm.camera_key=observation.images.robot0_agentview_right "
# Phase 1 — plan module (subtasks + memory).
# Embed decoded frames (not a file:// clip): if clip extraction fails,
# the video_url path silently sends no video and the VLM hallucinates.
"--plan.use_video_url=false "
"--plan.frames_per_second=1.0 "
# 32 frames ≈ 8-10k vision tokens, fits the 32768 context. Don't push
# toward 128 — that overflows the context (BadRequestError 400).
"--plan.max_video_frames=32 "
# Window long episodes into 32s chunks (constant 1 fps density) so they
# get more subtasks; per-window spans are merged + stitched. 0 disables.
"--plan.subtask_window_seconds=32 "
# RoboCasa: the dataset task string is authoritative (eval uses it), so
# keep it driving subtasks. ``always`` would throw it away and hallucinate.
"--plan.derive_task_from_video=off "
# No task augmentation: eval conditions on the exact task strings, so
# rephrasings are unused at best and harmful when they drift.
"--plan.n_task_rephrasings=0 "
# Keep subtask decomposition tight for atomic tasks.
"--plan.plan_max_steps=10 "
# Only subtasks + memory — skip the numbered "plan" rows. true re-enables.
"--plan.emit_plan=false "
# The describe->segment grounding pass (+1 VLM call/episode) is ON by
# default; pass --plan.subtask_describe_first=false to skip it.
# Phase 2 — interjections + speech.
"--interjections.max_interjections_per_episode=6 "
# Phase 4 — general VQA: disabled for this run.
"--vqa.enabled=false"
)
job = run_job(
image="vllm/vllm-openai:latest",
command=["bash", "-c", CMD],
flavor="h200",
secrets={"HF_TOKEN": token},
timeout="2h",
)
print(f"Job URL: {job.url}")
print(f"Job ID: {job.id}")
+2 -38
View File
@@ -85,11 +85,6 @@ dependencies = [
"termcolor>=2.4.0,<4.0.0",
"tqdm>=4.66.0,<5.0.0",
# Training utilities
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
# pure-python dependency — preferred over a hand-rolled implementation.
"ema-pytorch>=0.7.7,<1.0.0",
# Build tools (required by opencv-python-headless on some platforms)
"cmake>=3.29.0.1,<4.2.0",
"setuptools>=71.0.0,<81.0.0",
@@ -147,7 +142,6 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
@@ -203,7 +197,7 @@ wallx = [
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
@@ -218,40 +212,15 @@ groot = [
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
# which talks to any OpenAI-compatible server (``vllm serve`` /
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
# (see examples/annotations/run_hf_job.py).
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
"openai>=1.40,<2.0",
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
# uv's single unified lock would then cap ``torch`` for every extra
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
# install it locally only if you run your own ``vllm serve``.
]
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
# are isolated so adding a new tool doesn't bloat the base install.
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
tools = [
"pocket-tts>=1.0.0,<3.0.0",
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -312,7 +281,6 @@ all = [
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -323,7 +291,6 @@ all = [
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
@@ -346,10 +313,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
# ---------------- Tool Configurations ----------------
@@ -367,7 +331,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
-15
View File
@@ -1,15 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -1,36 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter
__all__ = [
"AnnotationPipelineConfig",
"LanguageColumnsWriter",
"StagingValidator",
"ValidationReport",
]
@@ -1,196 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class PlanConfig:
"""``plan`` module: subtasks + plan + memory + task augmentation."""
enabled: bool = True
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
n_task_rephrasings: int = 10
# Derive the task from video instead of episode_task: off / if_short / always.
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
derive_task_from_video: str = "if_short"
derive_task_min_words: int = 3
# Frames sampled uniformly, capped at max_video_frames — a hard context cap
# (~300 tokens/frame, so 32 fit a 32k VLM; 128 overflow).
frames_per_second: float = 1.0
max_video_frames: int = 32
# >0: split long episodes into windows of this length (constant fps density)
# instead of subsampling the whole episode; spans merged + stitched. 0 disables.
subtask_window_seconds: float = 0.0
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
# Narrate-only grounding pass before segmenting — best defense against subtasks
# invented from the task text (+1 VLM call/episode).
subtask_describe_first: bool = True
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
emit_plan: bool = True
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
# Send a server-side ``video_url`` clip (at use_video_url_fps) instead of embedded frames.
use_video_url: bool = False
use_video_url_fps: float = 1.0
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
@dataclass
class TaskAugAxesConfig:
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
enabled: bool = False
synonym_paraphrase: int = 3
omit_arm: int = 3
omit_orientation: int = 2
omit_grasp_method: int = 2
combined_omissions: int = 2
@dataclass
class InterjectionsConfig:
"""``interjections`` module: interjections + paired speech."""
enabled: bool = True
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
max_interjections_per_episode: int = 3
interjection_min_t: float = 2.0
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
interjection_window_seconds: float = 2.0
interjection_window_frames: int = 4
@dataclass
class VqaConfig:
"""``vqa`` module: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
K: int = 1
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
# True: ground VQA only on --vlm.camera_key (default: every camera).
restrict_to_default_camera: bool = False
@dataclass
class VlmConfig:
"""Shared Qwen-VL client configuration."""
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
# auto_serve=True); ``stub`` is for tests.
backend: str = "openai"
model_id: str = "Qwen/Qwen3.6-27B"
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
api_base: str = "http://localhost:8000/v1"
api_key: str = "EMPTY"
# Spawn a server if none answers api_base; False = fail fast on a remote.
auto_serve: bool = True
serve_port: int = 8000
# Override the auto-serve command; ``{port}`` substituted per replica.
serve_command: str | None = None
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
parallel_servers: int = 1
num_gpus: int = 0
client_concurrency: int = 16
serve_ready_timeout_s: float = 600.0
max_new_tokens: int = 512
temperature: float = 0.2
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
max_model_len: int | None = None
# Camera for keyframes; None → first ``observation.images.*`` key.
camera_key: str | None = None
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
chat_template_kwargs: dict[str, Any] | None = None
@dataclass
class ExecutorConfig:
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
# Episodes processed concurrently per phase; main knob for saturating the servers.
episode_parallelism: int = 16
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
# is on and ``new_repo_id`` unset.
repo_id: str | None = None
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
new_repo_id: str | None = None
root: Path | None = None
# Defaults to ``<root>/.annotate_staging/``.
staging_dir: Path | None = None
seed: int = 1729
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
# Keyframe decode backend. None → ffmpeg CLI (crash-/thread-safe; torchcodec
# SIGSEGVs under concurrent decode). Or ``"torchcodec"`` / ``"pyav"``.
video_backend: str | None = None
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
push_to_hub: bool = False
push_private: bool = False
push_commit_message: str | None = None
def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -1,253 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""In-process executor that runs the annotation phases.
The executor runs **six phases** in dependency order:
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
Episode-level concurrency is controlled by
``ExecutorConfig.episode_parallelism``.
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .config import AnnotationPipelineConfig
from .reader import EpisodeRecord, iter_episodes
from .staging import EpisodeStaging
from .validator import StagingValidator
from .writer import LanguageColumnsWriter
logger = logging.getLogger(__name__)
@dataclass
class PhaseResult:
"""Summary of one pipeline phase across all episodes."""
name: str
episodes_processed: int
episodes_skipped: int
@dataclass
class PipelineRunSummary:
"""Aggregated result returned by :meth:`Executor.run`."""
phases: list[PhaseResult]
written_paths: list[Path]
validation_report: Any # ValidationReport, kept Any to avoid import cycle
@dataclass
class Executor:
"""Run all six phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
executor inside a Hugging Face Job. Tests construct the executor
directly with stub modules.
"""
config: AnnotationPipelineConfig
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
def run(self, root: Path) -> PipelineRunSummary:
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
n = len(records)
if n == 0:
raise ValueError(f"No episodes found under {root}/data/")
print(f"[annotate] {n} episodes total", flush=True)
staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True)
phases: list[PhaseResult] = []
# Phase 1: ``plan`` module (plan + subtasks + memory)
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
# Phase 2: ``interjections`` module (interjections + speech). It
# reads the ``plan`` module's subtask rows from the same staging
# tree to ground the interjection prompt in the correct local subtask.
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}")
print(f"[annotate] validator: {report.summary()}", flush=True)
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Keep meta/info.json aligned with the parquet schema we just wrote.
# Idempotent and additive: existing user metadata is preserved.
self._ensure_annotation_metadata_in_info(root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
@staticmethod
def _ensure_annotation_metadata_in_info(root: Path) -> None:
"""Write language features and canonical tools to ``meta/info.json``.
``LanguageColumnsWriter`` adds ``language_persistent`` and
``language_events`` to parquet shards. The metadata must advertise
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
cast against the old schema and fail on the extra parquet columns.
"""
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
info_path = root / "meta" / "info.json"
if not info_path.exists():
return
try:
info = load_info(root)
except Exception as exc: # noqa: BLE001
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
return
changed = False
merged_features = {**info.features, **language_feature_info()}
if merged_features != info.features:
info.features = merged_features
changed = True
existing = info.tools or []
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
info.tools = [*existing, SAY_TOOL_SCHEMA]
changed = True
if changed:
write_info(info, root)
print(
"[annotate] meta/info.json: "
f"language_features={list(language_feature_info())}, "
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
flush=True,
)
def _run_module_phase(
self,
name: str,
records: list[EpisodeRecord],
staging_dir: Path,
module: Any,
) -> PhaseResult:
if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
n = len(records)
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
print(
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
flush=True,
)
t0 = time.time()
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = time.time()
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
return i, record.episode_index, time.time() - ep_start
processed = 0
if parallelism == 1:
for i, record in enumerate(records, 1):
_, ep_idx, elapsed = _do((i, record))
processed += 1
print(
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
flush=True,
)
else:
with ThreadPoolExecutor(max_workers=parallelism) as pool:
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
for fut in as_completed(futures):
i, ep_idx, elapsed = fut.result()
processed += 1
print(
f"[annotate] {name} episode {processed}/{n} "
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase( # noqa: PLR0915
self, records: list[EpisodeRecord], staging_dir: Path
) -> PhaseResult:
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
The ``plan`` module owns the prompt; the ``interjections`` module
produced the timestamps. This phase therefore calls back into the
``plan`` module with the interjection timestamps so its existing
prompt path is reused.
"""
if not self.plan.enabled or not self.interjections.enabled:
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_rows = [
row for row in staging.read("interjections") if row.get("style") == "interjection"
]
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
if interjection_times:
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
processed += 1
# Episodes without any interjections are skipped (no plan refresh
# needed); count them so the summary's processed+skipped == total.
return PhaseResult(
name="plan_update",
episodes_processed=processed,
episodes_skipped=len(records) - processed,
)
@@ -1,498 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keyframe extraction for the annotation pipeline.
Modules attach decoded camera frames to their VLM prompts so the model can
ground subtask decomposition, interjection scenarios, and VQA in actual
visual content. The pipeline shares one provider across modules and one
episode at a time, with a small per-episode cache so multiple modules
querying the same timestamp pay decode cost once.
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
import PIL.Image
import torch
from lerobot.datasets.video_utils import decode_video_frames
from .reader import EpisodeRecord
logger = logging.getLogger(__name__)
class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps."""
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` feature keys this provider can decode."""
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
:func:`to_image_blocks` converts them to PIL only at the VLM-message
boundary.
Empty list if the camera is unavailable. ``camera_key=None`` falls back
to the provider's default camera so existing single-camera callers
(the ``plan`` and ``interjections`` modules) keep working unchanged.
"""
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` decoded frames covering the whole episode.
Sampling is uniform across the episode duration. Frames are
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
them into one ``{"type":"video", "video":<list>}`` block for a
Qwen-VL-compatible model that pools temporally itself. Empty list if
no camera available.
"""
@dataclass
class _NullProvider:
"""No-op provider used when the dataset has no video keys or in tests."""
@property
def camera_keys(self) -> list[str]:
return []
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
return []
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
return []
def null_provider() -> FrameProvider:
return _NullProvider()
@dataclass
class VideoFrameProvider:
"""Decodes frames from the dataset's ``observation.images.*`` streams.
By default the *first* camera key is used for the ``plan`` module
(subtask decomposition) and the ``interjections`` module (interjection
scenarios) — those prompts care about *what is happening*, not which
angle. The ``vqa`` module instead iterates over every camera in
:attr:`camera_keys` so each frame's
grounded answer (bbox/keypoint/...) is tagged with the camera it was
grounded against.
``camera_key`` overrides the default-camera choice but does not restrict
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
``video_for_episode`` to read a non-default stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
"""
root: Path
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
# concurrency- and crash-safe default for the pipeline's threaded
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
# decoder when the build is known thread-safe.
video_backend: str | None = None
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
# one-shot warn flag against concurrent updates from worker threads.
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
def __post_init__(self) -> None:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
# Only ``video_keys`` are decodable here: the clip/decode paths read
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
# only for video-stored cameras. Image-stored cameras (also in
# ``camera_keys``) would KeyError, so restrict the list — and the
# default — to video keys.
keys = list(self._meta.video_keys)
# Last-resort fallback: if metadata didn't surface any video keys but
# the caller explicitly named a camera (``--vlm.camera_key=...``),
# trust them — the key is by definition known to exist on the dataset.
if not keys and self.camera_key:
keys = [self.camera_key]
self._camera_keys = keys
if self.camera_key is None:
self.camera_key = keys[0] if keys else None
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` keys available on this dataset."""
return list(self._camera_keys)
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
target = camera_key if camera_key is not None else self.camera_key
if not timestamps or target is None:
return []
out: list[Any] = []
misses: list[float] = []
miss_indices: list[int] = []
with self._lock:
for i, ts in enumerate(timestamps):
key = (record.episode_index, target, round(float(ts), 6))
cached = self._cache.get(key)
if cached is not None:
out.append(cached)
else:
out.append(None)
misses.append(float(ts))
miss_indices.append(i)
if misses:
decoded = self._decode(record.episode_index, misses, target)
# ``_decode`` returns exactly one frame per requested timestamp,
# or an empty list if decoding failed wholesale. A partial list
# would mean a frame/timestamp misalignment, so only pair them up
# when the counts match (``strict=True`` then guards regressions).
if len(decoded) == len(miss_indices):
with self._lock:
for i, frame in zip(miss_indices, decoded, strict=True):
out[i] = frame
key = (record.episode_index, target, round(float(timestamps[i]), 6))
if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = frame
# filter out any None left over from decode failures
return [frame for frame in out if frame is not None]
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
The whole episode duration is covered; the model picks subtask
boundaries from the temporal pooling it does internally. Frames are
``torch.Tensor`` (see :meth:`frames_at`).
"""
target = camera_key if camera_key is not None else self.camera_key
if max_frames <= 0 or target is None or not record.frame_timestamps:
return []
n_frames = min(max_frames, len(record.frame_timestamps))
if n_frames == len(record.frame_timestamps):
timestamps = list(record.frame_timestamps)
else:
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
if t_last <= t0:
timestamps = [float(t0)] * n_frames
else:
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
timestamps = [float(t0 + i * step) for i in range(n_frames)]
return self.frames_at(record, timestamps, camera_key=target)
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
Returns ``None`` if the dataset has no video tracks. Skips
re-extract when the cached clip already exists. Re-encodes to
H.264 (libx264) so the resulting mp4 is decodable by every
downstream video processor — stream-copy would inherit the
source codec (often AV1 in modern LeRobot datasets), which
vllm's libav build cannot decode.
"""
import subprocess # noqa: PLC0415
if self.camera_key is None:
return None
cache_dir.mkdir(parents=True, exist_ok=True)
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
if out_path.exists() and out_path.stat().st_size > 0:
return out_path
ep = self._meta.episodes[record.episode_index]
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
cmd = [
"ffmpeg",
"-y",
"-loglevel",
"error",
"-ss",
f"{from_timestamp:.3f}",
"-to",
f"{to_timestamp:.3f}",
"-i",
str(src),
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-crf",
"23",
"-pix_fmt",
"yuv420p",
"-an",
str(out_path),
]
try:
# ffmpeg is invoked by name via PATH lookup (the standard way to
# call the CLI); the arg list is fully controlled here, not shell.
subprocess.run(cmd, check=True, timeout=300) # nosec B607
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
return None
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
Returns one frame per requested timestamp, or ``[]`` if decoding
failed wholesale — callers treat ``[]`` as "no frames available".
"""
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
# ThreadPoolExecutor and the in-process decoders are unsafe there:
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
# (a crash no try/except can catch), PyAV can likewise segfault on
# AV1, and lerobot's ``pyav`` backend routes through the removed
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
# out per frame: each decode is an isolated child process, so it is
# both crash-safe and concurrency-safe. ``video_backend`` can pin
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
# build is safe.
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
exc: Exception | None = None
for backend in chain:
try:
if backend == "ffmpeg":
return _decode_frames_ffmpeg(video_path, shifted)
if backend in ("pyav", "av"):
return _decode_frames_av(video_path, shifted)
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
)
return list(decoded)
except Exception as e: # noqa: PERF203
exc = e
# Every backend raised. Log loudly the first time so a silent
# vqa-module no-op (every prompt skipped because frames_at returned
# []) is debuggable from the job log instead of post-hoc parquet
# inspection. Subsequent failures stay quiet.
with self._lock:
already_warned = self._warned_decode_fail
if not already_warned:
self._warned_decode_fail = True
if not already_warned:
logger.warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backends=%s: %s",
episode_index,
camera_key,
video_path,
chain,
exc,
exc_info=exc,
)
return []
def make_frame_provider(
root: Path, camera_key: str | None = None, video_backend: str | None = None
) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
try:
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
except Exception:
return null_provider()
if provider.camera_key is None:
return null_provider()
return provider
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
piping a single PNG to stdout. Unlike the in-process decoders this
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
modern LeRobot datasets use) where torchcodec raises and PyAV can
SIGSEGV, and a crash stays isolated to the child process — a non-zero
exit is a catchable error, not a segfault of the whole job. Returns one
``(C, H, W)`` uint8 tensor per timestamp.
"""
import io # noqa: PLC0415
import subprocess # noqa: PLC0415
import numpy as np # noqa: PLC0415
frames: list[Any] = []
for ts in timestamps:
# ffmpeg invoked by name via PATH lookup; fully-controlled arg list, no shell.
proc = subprocess.run( # nosec B607
[
"ffmpeg",
"-nostdin",
"-loglevel",
"error",
"-ss",
f"{max(ts, 0.0):.3f}",
"-i",
str(video_path),
"-frames:v",
"1",
"-f",
"image2pipe",
"-vcodec",
"png",
"pipe:1",
],
capture_output=True,
check=True,
timeout=120,
)
if not proc.stdout:
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
return frames
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
lerobot's ``decode_video_frames(backend="pyav")`` routes through
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
fallback; this stays available behind ``video_backend="pyav"``. Returns
one ``(C, H, W)`` uint8 tensor per timestamp.
"""
import av # noqa: PLC0415
first_ts = min(timestamps)
last_ts = max(timestamps)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
with av.open(str(video_path)) as container:
stream = container.streams.video[0]
# Seek to the keyframe at or before the first requested timestamp.
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
container.seek(offset, stream=stream, backward=True, any_frame=False)
for idx, frame in enumerate(container.decode(stream)):
ts = frame.time
if ts is None:
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
loaded_ts.append(ts)
loaded_frames.append(
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
)
if ts >= last_ts:
break
if not loaded_frames:
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
ts_tensor = torch.tensor(loaded_ts)
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
def _frame_to_pil(frame: Any) -> Any:
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
straight from :func:`decode_video_frames`); PIL is only created here, at
the VLM-message boundary, because the chat backends expect PIL images /
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
"""
if not isinstance(frame, torch.Tensor):
return frame
array = frame.detach().cpu()
if array.ndim == 3 and array.shape[0] in (1, 3):
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
if array.shape[-1] == 1:
array = array.squeeze(-1)
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of decoded frames as one Qwen-VL video block.
Returns ``[]`` when the list is empty, so the caller can splat the result
into a content array without a separate emptiness check.
"""
if not frames:
return []
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
"""Wrap a video file URL as one ``video_url`` block.
Used by the ``openai`` backend (transformers serve / vllm serve /
ktransformers serve), where the server handles frame sampling.
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
"""
if not url:
return []
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
@@ -1,25 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
]
@@ -1,248 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``vqa`` module: general VQA at a timed cadence.
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
consecutive frames, and every anchored frame gets its own VQA pair. Each
pair is grounded on that single anchor frame — there is no per-pair frame
window. For datasets with multiple cameras, every anchored frame produces
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
generated against that camera's frame and stamped with the matching
``camera`` field on the emitted rows. The resolver disambiguates via
``camera=...``; recipes that consume VQA do so through one sub-recipe
per camera (see ``recipes/pi05_hirobot.yaml``).
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
"""
from __future__ import annotations
import json
import logging
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import VqaConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..validator import classify_vqa_answer
from ..vlm_client import VlmClient
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
"""Return the relative frame indices to anchor VQA emissions to.
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
consecutive frames starting at the tick. Ticks fall on the nearest
available source frame timestamp.
"""
if hz <= 0 or k <= 0 or not frame_timestamps:
return []
t0 = frame_timestamps[0]
t_last = frame_timestamps[-1]
period = 1.0 / hz
indices: list[int] = []
t = t0
while t <= t_last + 1e-9:
# find the index of the nearest frame to t
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
for offset in range(k):
j = nearest_i + offset
if j >= len(frame_timestamps):
break
if not indices or indices[-1] != j:
indices.append(j)
t += period
# dedupe while preserving order
seen: set[int] = set()
deduped: list[int] = []
for i in indices:
if i in seen:
continue
seen.add(i)
deduped.append(i)
return deduped
@dataclass
class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: VqaConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
_warned_no_camera: bool = field(default=False, init=False, repr=False)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("vqa", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
cameras = self._target_cameras()
if not cameras:
# No camera available — emit nothing rather than producing
# untagged rows that would fail validation. Surface a loud one-
# time warning so this is never silently a no-op.
if not self._warned_no_camera:
logging.getLogger(__name__).warning(
"vqa module found no cameras on the frame provider — "
"every episode will emit zero VQA rows. Check that the "
"dataset declares observation.images.* features in "
"meta/info.json; passing --vlm.camera_key=<key> at the "
"CLI now also seeds the cameras list as a fallback."
)
self._warned_no_camera = True
staging.write("vqa", [])
return
# Build all messages first (one per (frame, camera)), then issue them
# as a single batched generate_json call so the client can fan them
# out concurrently.
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
for camera in cameras:
messages = self._build_messages(record, qtype, ts, camera)
# Skip cameras that decoded to zero frames at this ts: no point
# asking the VLM to ground a bbox without an image.
if not _has_image_block(messages):
continue
per_call.append((ts, camera, qtype, messages))
if not per_call:
staging.write("vqa", [])
return
results = self.vlm.generate_json([m for _, _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
qa = self._postprocess(result)
if qa is None:
continue
question, answer = qa
rows.append(
{
"role": "user",
"content": question,
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
rows.append(
{
"role": "assistant",
"content": json.dumps(answer, sort_keys=True),
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
staging.write("vqa", rows)
def _target_cameras(self) -> list[str]:
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
Defaults to every camera the provider exposes. Datasets with no
cameras (or test/null providers) yield an empty list, which makes
``run_episode`` a no-op.
When ``config.restrict_to_default_camera`` is set, VQA grounds on
only the provider's default camera (the single ``--vlm.camera_key``
stream), matching the plan / interjection modules so the whole
pipeline focuses on one view.
"""
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
if getattr(self.config, "restrict_to_default_camera", False):
default = getattr(self.frame_provider, "camera_key", None)
if default and default in all_cameras:
return [default]
# ``restrict_to_default_camera`` is set but the configured default
# isn't one the provider exposes. Returning it anyway would make
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
# fall through to every available camera instead.
if default:
logging.getLogger(__name__).warning(
"restrict_to_default_camera is set but camera_key=%r is not in the "
"provider's cameras %s; grounding VQA on all available cameras instead.",
default,
all_cameras,
)
return all_cameras
def _build_messages(
self,
record: EpisodeRecord,
question_type: str,
frame_timestamp: float,
camera_key: str,
) -> list[dict[str, Any]]:
prompt = load_prompt("vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict):
return None
question = result.get("question")
answer = result.get("answer")
if not isinstance(question, str) or not question.strip():
return None
if not isinstance(answer, dict):
return None
# The validator will enforce shape; here we just sanity-check that the
# answer matches *some* known shape so we can drop garbage early.
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
"""Return True if any user content block is a populated image block."""
for msg in messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for block in content:
if isinstance(block, dict) and block.get("type") == "image":
return True
return False
@@ -1,211 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
canonical task). No interjection row — the canonical task is already the
user utterance from ``meta/tasks.parquet``.
2. For mid-episode interruptions, emit a co-timestamped pair:
{role:user, style:interjection, content:<text>}
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
interjection timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import InterjectionsConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom
@dataclass
class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: InterjectionsConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
if record.frame_timestamps:
t0 = float(record.frame_timestamps[0])
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
# Pull the ``plan`` module's subtask spans for this episode so the
# interjection prompt can ground itself in the actual current
# subtask at each chosen timestamp. The ``plan`` module ran first.
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
rows.extend(self._mid_episode_interjections(record, subtask_spans))
staging.write("interjections", rows)
@staticmethod
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
current: str | None = None
for span in spans:
if float(span["start"]) <= t:
current = span.get("text")
else:
break
return current
def _initial_speech(self, record: EpisodeRecord) -> str | None:
prompt = load_prompt("interjections_initial_speech").format(
episode_task=record.episode_task,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("text"), str):
text = result["text"].strip()
if text:
return text
return None
def _mid_episode_interjections(
self,
record: EpisodeRecord,
subtask_spans: Sequence[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Generate interjections aligned with the actual demo trajectory.
Teleop data is frozen — the robot already executed every step in
the video. A *counterfactual* interjection like "actually skip
the wipe" contradicts what then happens in the video, which is
what qwen36moe-10/11 surfaced as low-quality interjections.
Instead, anchor every interjection at a subtask boundary and
write it as a natural user request for the *upcoming* subtask.
The robot's visible next behavior IS the interjection's effect,
so the training signal stays consistent: interjection text →
plan refresh → action stream all line up.
"""
if self.config.max_interjections_per_episode <= 0:
return []
if len(subtask_spans) < 2:
# Need at least one transition (subtask 0 → subtask 1).
return []
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
# Boundaries: the start time of every subtask except the first
# (which is just t0 and is covered by the initial-task speech atom).
boundaries: list[tuple[float, str, str]] = []
for i in range(1, len(subtask_spans)):
ts = float(subtask_spans[i]["start"])
if ts < self.config.interjection_min_t:
continue
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
next_text = (subtask_spans[i].get("text") or "").strip()
if not next_text:
continue
boundaries.append((ts, prev_text, next_text))
if not boundaries:
return []
n = min(self.config.max_interjections_per_episode, len(boundaries))
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
out: list[dict[str, Any]] = []
for t, prev_subtask, next_subtask in chosen:
t_snap = snap_to_frame(t, record.frame_timestamps)
# Window straddles the boundary so the VLM sees the end of the
# previous subtask and the start of the next one — same
# conditioning the policy will see at training time.
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
prompt = load_prompt("interjections_interjection").format(
episode_task=record.episode_task,
prev_subtask=prev_subtask or "(starting from initial state)",
next_subtask=next_subtask,
timestamp=t_snap,
window_seconds=self.config.interjection_window_seconds,
)
images = self.frame_provider.frames_at(record, window_ts)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
interjection_text = result.get("interjection")
speech_text = result.get("speech")
if not isinstance(interjection_text, str) or not interjection_text.strip():
continue
if not isinstance(speech_text, str) or not speech_text.strip():
continue
out.append(
{
"role": "user",
"content": interjection_text.strip(),
"style": "interjection",
"timestamp": t_snap,
"tool_calls": None,
}
)
out.append(speech_atom(t_snap, speech_text.strip()))
return out
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
"""Return a small set of frame timestamps centered on ``t_anchor``.
The window straddles the subtask boundary the interjection sits
on: roughly half the frames cover the end of the previous
subtask, half cover the start of the next one. The VLM therefore
sees BOTH what just finished AND what's about to start, which is
the conditioning we need to write a natural "now please do X"
request that matches the visible upcoming behavior.
"""
if not frame_timestamps:
return [t_anchor]
n = max(1, int(self.config.interjection_window_frames))
if n == 1:
return [t_anchor]
window = float(self.config.interjection_window_seconds)
step = window / max(1, n - 1)
# Center the window on the anchor so half lands before, half after.
start_offset = -window / 2.0
targets = [t_anchor + start_offset + step * i for i in range(n)]
first_ts = float(frame_timestamps[0])
last_ts = float(frame_timestamps[-1])
snapped: list[float] = []
seen: set[float] = set()
for tgt in targets:
clamped = min(last_ts, max(first_ts, tgt))
t = snap_to_frame(clamped, frame_timestamps)
if t not in seen:
seen.add(t)
snapped.append(t)
return snapped or [t_anchor]
@@ -1,712 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from ..config import PlanConfig
from ..frames import (
FrameProvider,
VideoFrameProvider,
null_provider,
to_video_block,
to_video_url_block,
)
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
logger = logging.getLogger(__name__)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
All output is persistent (lives in ``language_persistent``):
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
the ``interjections`` module completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: PlanConfig
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
# Task driving every plan-module prompt: canonical episode_task, or a
# video-derived one when it's empty/placeholder (see derive_task_*).
effective_task = self._resolve_effective_task(record)
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
# free-form n_task_rephrasings; the effective task is always emitted
# first so the rotation covers the source-of-truth phrasing.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
variants: list[str] | None = None
if self.config.task_aug_axes.enabled and effective_task:
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
elif self.config.n_task_rephrasings > 0 and effective_task:
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
if variants is not None:
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
subtask_spans = self._generate_subtasks(record, task=effective_task)
# subtask rows
for span in subtask_spans:
rows.append(
{
"role": "assistant",
"content": span["text"],
"style": "subtask",
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
"tool_calls": None,
}
)
# Plan rows at every subtask boundary (incl. t=0). The plan is a
# numbered list of still-todo subtasks, so re-emitting at each
# boundary makes it shrink as work progresses — ${plan} at frame t is
# exactly what's left to do.
if self.config.emit_plan:
for span in subtask_spans:
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
plan_text = self._generate_plan(
record, subtask_spans, refresh_t=boundary_t, task=effective_task
)
if plan_text is not None:
rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": float(boundary_t),
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start
prior_memory = ""
for i, span in enumerate(subtask_spans[1:], start=1):
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
if mem_text:
ts = snap_to_frame(span["start"], record.frame_timestamps)
rows.append(
{
"role": "assistant",
"content": mem_text,
"style": "memory",
"timestamp": ts,
"tool_calls": None,
}
)
prior_memory = mem_text
staging.write("plan", rows)
# ------------------------------------------------------------------
# Task derivation + rephrasings
# ------------------------------------------------------------------
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
{
"debug",
"test",
"tbd",
"todo",
"n/a",
"na",
"untitled",
"unnamed",
"default",
"placeholder",
}
)
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
"""Decide which task string drives the ``plan`` module for this episode.
Returns the user-supplied ``record.episode_task`` unless
``derive_task_from_video`` says otherwise (see config docstring).
Falls back gracefully to the canonical task if video derivation
fails.
"""
canonical = (record.episode_task or "").strip()
mode = (self.config.derive_task_from_video or "off").strip().lower()
if mode == "always":
derived = self._derive_task_from_video(record)
return derived or canonical
if mode == "if_short" and self._task_seems_bad(canonical):
derived = self._derive_task_from_video(record)
if derived:
return derived
return canonical
def _task_seems_bad(self, task: str) -> bool:
if not task:
return True
if len(task.split()) < int(self.config.derive_task_min_words):
return True
return task.lower() in self._PLACEHOLDER_TASKS
@staticmethod
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
seen: set[str] = set()
rows: list[dict[str, Any]] = []
for phrasing in phrasings:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
)
return rows
# ------------------------------------------------------------------
# VLM call helpers — every plan-module prompt follows the same shape:
# build messages → single VLM call → pull a named field.
# ------------------------------------------------------------------
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
"""Run a single VLM call and return ``result[field]`` or ``None``.
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
dance every prompt-call site needs.
"""
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict):
return result.get(field)
return None
@staticmethod
def _text_message(text: str) -> list[dict[str, Any]]:
"""One-shot text-only user message wrapped for ``generate_json``."""
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
def _video_message(
self,
record: EpisodeRecord,
prompt: str,
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""User message combining the (optionally windowed) video block with ``prompt``."""
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
"""Ask the VLM "what is this video about" with no task hint at all."""
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
return text.strip() if isinstance(text, str) and text.strip() else None
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
"""Generate ``n`` text-only paraphrases of ``base_task``."""
if n <= 0 or not base_task:
return []
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
if not isinstance(raw, list):
return []
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
return [s for s in out if s][:n]
# ------------------------------------------------------------------
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
# ------------------------------------------------------------------
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
"""One VLM call → variants along the 5-axis taxonomy.
Variants from all axes are flattened into a single list (the
downstream pipeline doesn't need to know about the per-axis
bucketing — every variant becomes a ``task_aug`` row). Order
is preserved for reproducibility: synonym_paraphrase first,
then omit_arm, then omit_orientation, then omit_grasp_method,
then combined_omissions.
"""
if not base_task:
return []
prompt = load_prompt("plan_task_aug_axes").format(
base_task=base_task,
n_synonym=axes_cfg.synonym_paraphrase,
n_omit_arm=axes_cfg.omit_arm,
n_omit_orientation=axes_cfg.omit_orientation,
n_omit_grasp_method=axes_cfg.omit_grasp_method,
n_combined=axes_cfg.combined_omissions,
)
result = self.vlm.generate_json([self._text_message(prompt)])[0]
if not isinstance(result, dict):
return []
ordered_axes = (
"synonym_paraphrase",
"omit_arm",
"omit_orientation",
"omit_grasp_method",
"combined_omissions",
)
flat: list[str] = []
seen: set[str] = set()
for axis in ordered_axes:
entries = result.get(axis)
if not isinstance(entries, list):
continue
for item in entries:
if not isinstance(item, str):
continue
key = item.strip().strip('"').strip("'")
if not key or key in seen:
continue
seen.add(key)
flat.append(key)
return flat
def _episode_video_block(
self, record: EpisodeRecord, window: tuple[float, float] | None = None
) -> list[dict[str, Any]]:
"""Video block for the segmentation / describe prompts.
Always returns a block that actually carries the video. When
``use_video_url`` is set we try the server-side ``video_url``
path first, but if clip extraction fails we FALL BACK to
decoding + embedding frames rather than returning an empty
block — an empty block would leave the VLM with no visual
grounding at all and it would hallucinate subtasks purely from
the task text.
When ``window=(w0, w1)`` is given (windowed subtask generation,
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
density regardless of episode length, so long episodes are split
into windows rather than subsampled to a sparse 32-frame whole-
episode view. The ``video_url`` path is skipped for windows (it is
a whole-episode clip). ``max_video_frames`` still caps each window
as a context-budget safety net.
"""
if not record.frame_timestamps:
return []
if window is not None:
w0, w1 = float(window[0]), float(window[1])
dur = max(0.0, w1 - w0)
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_video_frames)
if n <= 1 or dur <= 0.0:
timestamps = [0.5 * (w0 + w1)]
else:
step = dur / (n - 1)
timestamps = [w0 + i * step for i in range(n)]
return to_video_block(self.frame_provider.frames_at(record, timestamps))
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
clip = self.frame_provider.episode_clip_path(record, cache_dir)
if clip is not None:
return to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
logger.warning(
"episode %d: video_url clip extraction failed — falling back to "
"embedded frames so the VLM still sees the demonstration",
record.episode_index,
)
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
target_count = min(target_count, self.config.max_video_frames)
video_frames = self.frame_provider.video_for_episode(record, target_count)
return to_video_block(video_frames)
def run_plan_updates(
self,
record: EpisodeRecord,
staging: EpisodeStaging,
interjection_times: Sequence[float],
interjection_texts: Sequence[str] | None = None,
) -> None:
"""Append additional ``plan`` rows at every interjection timestamp.
Plans refresh ONLY on user interjections (event-driven). The
interjection text is forwarded into the prompt so the refreshed plan
reflects the user's correction.
"""
if not self.config.emit_plan:
return
existing = staging.read("plan")
# Pass the last frame timestamp so the final span is closed (else its
# end == start, zero duration, and a refresh inside it is missed).
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
new_rows = list(existing)
texts: list[str | None] = (
[None] * len(interjection_times)
if interjection_texts is None
else [str(t) if t else None for t in interjection_texts]
)
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
t = snap_to_frame(raw_t, record.frame_timestamps)
if t in already_planned:
continue
already_planned.add(t)
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
if plan_text is not None:
new_rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": t,
"tool_calls": None,
}
)
staging.write("plan", new_rows)
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
"""Generate subtask spans, optionally via a multi-call quality chain.
Single call (default): watch video → emit subtask JSON.
Multi-call (opt-in, higher quality, more VLM calls):
1. ``subtask_describe_first`` — a grounding pass that narrates
ONLY what is visible (no JSON commitment to subtasks yet);
its description is injected into the segmentation prompt so
the model segments its own grounded observations instead of
pattern-matching the task text.
2. segmentation — emit subtask JSON (as before).
"""
if record.row_count == 0 or not record.frame_timestamps:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
effective_task = task if task is not None else record.episode_task
# ---- Windowed path (constant temporal density) ---------------
# If subtask_window_seconds > 0 and the episode exceeds one window,
# process fixed-length windows so the VLM always sees
# frames_per_second density; results are merged + stitched.
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
if window_s > 0.0 and episode_duration > window_s:
return self._generate_subtasks_windowed(record, effective_task, window_s)
# ---- Pass 1 (optional): grounding description ----------------
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, effective_task)
if description:
observation_block = (
"You watched this video and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the video) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
# ---- Pass 2: segmentation ------------------------------------
prompt = load_prompt("plan_subtasks").format(
episode_task=effective_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
observation_block=observation_block,
)
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
cleaned = self._clean_spans(spans, record)
if not cleaned:
return []
# ---- Full-episode coverage stitch ----------------------------
# The VLM can start after t0 or leave gaps, so frames fall through
# with no active subtask. Always stitch into a contiguous
# [t0, t_last] cover.
cleaned = self._stitch_full_coverage(cleaned, record)
return cleaned
def _generate_subtasks_windowed(
self, record: EpisodeRecord, task: str, window_s: float
) -> list[dict[str, Any]]:
"""Subtask generation in fixed-length windows at constant fps.
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
seconds, runs the describe -> segment chain on each window's own
frames (sampled at ``frames_per_second``), offsets
each window's spans back to absolute episode time, then merges +
stitches into a contiguous whole-episode cover.
"""
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
all_spans: list[dict[str, Any]] = []
w0 = t0
n_windows = 0
while w0 < t_last - 1e-6:
w1 = min(w0 + window_s, t_last)
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
n_windows += 1
w0 = w1
logger.info(
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
record.episode_index,
n_windows,
window_s,
len(all_spans),
)
# Merge across windows: clamp to the absolute episode, sort, and
# frame-snap to distinct starts (handles any boundary collisions).
cleaned = self._clean_spans(all_spans, record)
if not cleaned:
return []
return self._stitch_full_coverage(cleaned, record)
def _subtasks_for_window(
self, record: EpisodeRecord, task: str, w0: float, w1: float
) -> list[dict[str, Any]]:
"""Run describe -> segment on one ``[w0, w1]`` window.
The model works in window-RELATIVE time ``[0, L]`` (it perceives
the window as a clip starting at 0); spans are offset back to
absolute ``[w0, w1]`` before returning.
"""
window = (w0, w1)
win_len = max(0.0, w1 - w0)
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, task, window=window)
if description:
observation_block = (
"You watched this video clip and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the clip) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
prompt = load_prompt("plan_subtasks").format(
episode_task=task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{win_len:.3f}",
observation_block=observation_block,
)
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Window-relative clamp; no frame-snap dedupe yet (done on the
# merged absolute set).
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
if not cleaned:
return []
# Offset window-relative spans back to absolute episode time.
for s in cleaned:
s["start"] = w0 + float(s["start"])
s["end"] = w0 + float(s["end"])
return cleaned
def _stitch_full_coverage(
self, spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Make subtask spans tile the full episode with no gaps.
* The first subtask starts at the episode's first frame ``t0``
(any idle / approach before the first labelled action is folded
into it), so every early frame has an active subtask.
* Each subtask's ``end`` is snapped to the next subtask's
``start`` (gaps between spans are closed), and the final
subtask's ``end`` extends to the last frame ``t_last``.
Starts are otherwise left as the (already frame-snapped, distinct)
values the VLM produced — only the FIRST start is pulled
back to ``t0``, which can't collide with a later span because it
was already the earliest. Purely deterministic; runs after the
VLM passes.
"""
if not spans or not record.frame_timestamps:
return spans
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
spans = sorted(spans, key=lambda s: float(s["start"]))
spans[0]["start"] = t0
for i in range(len(spans) - 1):
spans[i]["end"] = float(spans[i + 1]["start"])
spans[-1]["end"] = t_last
for s in spans:
if float(s["end"]) < float(s["start"]):
s["end"] = float(s["start"])
return spans
def _clean_spans(
self,
spans: Any,
record: EpisodeRecord,
bounds: tuple[float, float] | None = None,
dedupe: bool = True,
) -> list[dict[str, Any]]:
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
``bounds`` overrides the clamp range — pass the window's
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
``None`` to clamp to the whole episode ``[t0, t_last]``.
``dedupe`` runs the frame-snap distinct-start step; skip it for
window-relative spans (frame snapping is done once on the merged,
absolute-time set).
"""
if not spans:
return []
if bounds is not None:
lo, hi = float(bounds[0]), float(bounds[1])
else:
lo = record.frame_timestamps[0]
hi = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
start = float(span["start"])
end = float(span["end"])
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(lo, min(start, hi))
end = max(lo, min(end, hi))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
if dedupe:
return self._dedupe_starts_to_distinct_frames(cleaned, record)
return cleaned
def _describe_episode(
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
) -> str:
"""Grounding pass: free-form chronological description of the (windowed) video."""
prompt = load_prompt("plan_subtask_describe").format(episode_task=task)
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
return text.strip() if isinstance(text, str) and text.strip() else ""
@staticmethod
def _dedupe_starts_to_distinct_frames(
spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Bump same-frame subtask starts onto distinct frames.
Two consecutive VLM spans whose ``start`` rounds to the same
source frame (after :func:`snap_to_frame`) would otherwise emit
two ``style=subtask`` rows at the identical persistent
timestamp. The training-time renderer's ``active_at(t,
style=subtask)`` resolver can't disambiguate that and raises
``Ambiguous resolver for style='subtask'``.
Walk the (sorted-by-start) spans, snap each to its frame, and
if the snapped frame is already taken push the span onto the
next unused frame so both subtasks survive on distinct
timestamps. If the episode ends before a free frame is found,
the trailing span is dropped with a warning — better than
poisoning the render.
"""
if not spans:
return spans
frames = record.frame_timestamps
if not frames:
return spans
used: set[float] = set()
out: list[dict[str, Any]] = []
for span in spans:
ts = snap_to_frame(span["start"], frames)
if ts in used:
next_ts = next((f for f in frames if f > ts and f not in used), None)
if next_ts is None:
logger.warning(
"episode %d: subtask %r snapped to occupied frame "
"%.3f and no free later frame exists — dropping",
record.episode_index,
span.get("text"),
ts,
)
continue
ts = next_ts
used.add(ts)
new_span = {**span, "start": ts}
if float(new_span.get("end", ts)) < ts:
new_span["end"] = ts
out.append(new_span)
return out
def _generate_plan(
self,
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
subtask_spans: Sequence[dict[str, Any]],
*,
refresh_t: float | None = None,
interjection: str | None = None, # noqa: ARG002
task: str | None = None, # noqa: ARG002
) -> str | None:
"""Deterministic plan = numbered list of *still-todo* subtasks.
No VLM call: a plain numbered list keeps the plan aligned with the
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
cost a round-trip per episode/refresh and could diverge).
1. <subtask 1>
2. <subtask 2>
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
interjections, and ``run_episode`` at each boundary), only subtasks
starting at or after ``refresh_t`` are included — so it always
describes what's left.
"""
if not subtask_spans:
return None
remaining = [
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
]
if not remaining:
# Past the last subtask boundary on a late refresh — nothing
# left to plan; emit None so the caller skips the row.
return None
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
def _generate_memory(
self,
record: EpisodeRecord,
prior_memory: str,
completed: str,
remaining: Sequence[str],
*,
task: str | None = None,
) -> str:
prompt = load_prompt("plan_memory").format(
episode_task=(task if task is not None else record.episode_task),
prior_memory=prior_memory or "(none)",
completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
)
memory = self._vlm_field(self._text_message(prompt), "memory")
return memory.strip() if isinstance(memory, str) else ""
@@ -1,33 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.
One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""
from __future__ import annotations
from pathlib import Path
_DIR = Path(__file__).parent
def load(name: str) -> str:
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
path = _DIR / f"{name}.txt"
return path.read_text(encoding="utf-8")
@@ -1,12 +0,0 @@
The user just asked the robot: "{episode_task}".
Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: compact, confident, friendly.
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".
Prefer very short replies: "Got it.", "On it.", "OK."
Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -1,46 +0,0 @@
You are generating training data for a Hi Robot-style hierarchical
robot policy. The robot in this demonstration has ALREADY executed
every step shown in the video — we cannot retroactively change the
action stream. To keep training data consistent with the video, the
"interjection" must align with what the robot is *about to do next* in
the demonstration, framed as a natural mid-task user request.
The episode's overall task: "{episode_task}".
The images above show roughly {window_seconds:.1f} seconds straddling a
subtask boundary in the demonstration:
- Subtask the robot just finished: "{prev_subtask}"
- Subtask the robot is about to start: "{next_subtask}"
- Time into episode: {timestamp:.2f}s
Write ONE compact interjection the user would naturally say at this
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
Keep it like a mid-task coaching cue, not a full instruction paragraph.
Also write the robot's compact verbal acknowledgement.
Hard rules:
- The interjection MUST be consistent with the next subtask. The user
cannot ask for something different from what the robot then does in
the video. If you're tempted to say "actually skip X" or "do Y
instead", DO NOT — those would contradict the demonstration.
- The interjection must reference an object, location, or action that
is plausible given the visible scene and the next subtask text.
- One short phrase or sentence each. Conversational, not robotic.
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
- Keep robot speech very short: "OK.", "On it.", "Doing that."
Style examples (vary the phrasing — don't reuse these verbatim):
- "Now go ahead and {next_subtask}."
- "Great, can you {next_subtask} next?"
- "{next_subtask}, please."
- "Before you continue, please {next_subtask}."
- "Looking good — {next_subtask} now."
- "Okay, {next_subtask}."
Output strictly valid JSON:
{{
"interjection": "<short cue from the user, asking for the next subtask>",
"speech": "<short robot acknowledgement>"
}}
@@ -1,36 +0,0 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.
Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."
Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
robot has accomplished so far — the running story it would tell itself.
Authoring rules:
- First person, past tense. Every sentence starts with "I": "I picked
up...", "I opened...", "I moved to...".
- One or two short sentences. Extend the previous memory with the
just-completed subtask; do not rewrite it from scratch.
- Keep WHAT happened (functional outcomes — where items went, how many),
drop HOW (grasp details, motions).
- Compress completed steps and drop object attributes (colors, exact
counts) once they no longer affect the remaining subtasks.
Example (MEM, Torne 2026):
Before: "I prepared the pot and got the potatoes, milk, and butter. I
moved to the drawer."
After: "I prepared the pot and got the ingredients. I opened the
drawer with the masher."
Output strictly valid JSON:
{{ "memory": "<one or two short first-person past-tense sentences>" }}
@@ -1,27 +0,0 @@
You are watching a teleoperated robot demonstration from a single
camera. The user asked the robot to: "{episode_task}"
This is an OBSERVATION pass. Watch the entire clip and describe, in
chronological order, ONLY what the robot physically does — the concrete
motions, approaches, contacts, grasps, releases, and relocations you can
actually SEE in the frames.
Hard rules:
- Describe only motion visible in the video. Do NOT use the task
instruction to guess steps that aren't shown. The instruction is the
goal; the video is ground truth.
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
the single field below. Just narrate what happens.
- Give an approximate timestamp (in seconds) for each distinct event,
e.g. "0.0-1.4s: the base drives forward toward the stove".
- Do NOT invent objects, grasps, destinations, or steps. If the robot
only does one thing (e.g. it just navigates and the clip ends), say
exactly that and nothing more.
- Be concrete and literal. "the gripper closes on the mug" — not "the
robot prepares to make coffee".
Output strictly valid JSON:
{{
"description": "<chronological, timestamped description of ONLY what is visible>"
}}
@@ -1,112 +0,0 @@
You are labeling a teleoperated robot demonstration.
The user originally asked: "{episode_task}"
You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs.
{observation_block}GROUNDING — read this first, it overrides everything below:
- Label ONLY what the robot actually does in the video. Every subtask
you emit must correspond to motion you can SEE in specific frames.
- Do NOT invent, anticipate, or pad. If the robot only does one thing
(e.g. it just navigates to a location and the clip ends), emit
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
subtasks than the ceiling is not just allowed, it is expected for
short / atomic demonstrations. One correct subtask is far better
than several invented ones.
- If the video does not clearly show the action implied by the task,
describe what you actually see — do NOT fabricate the task's steps
from the instruction text. The instruction tells you the goal; the
VIDEO is the ground truth for what happened.
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
- Each subtask = one COMPOSITE atomic skill the low-level policy can
execute end-to-end. A "skill" bundles its own approach motion with
its terminal action — do NOT split the approach off as its own
subtask. The whole-arm policy already learns to reach as part of
every manipulation primitive.
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
these verbs (extend only when none fits):
pick up <obj> — approach + grasp + lift in one subtask
put <obj> on/in <loc> — transport + release in one subtask
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
push <obj> — contact + linear shove
pull <obj> — contact + linear retract
turn <knob/dial/handle> — rotary actuation
press <button> — single-press contact
open <drawer/door/lid> — full open motion
close <drawer/door/lid> — full close motion
pour <src> into <dst> — tilt + flow
insert <obj> into <slot>— alignment + push-fit
go to <loc> — ONLY when no grasp / actuation follows
(e.g. a pure relocation between phases).
If the next subtask grasps something at
that location, drop "go to ..." and just
write "pick up ..." instead.
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
as standalone subtasks; fold them into the parent composite:
"move to X" → fold into "pick up X" (or whatever follows)
"reach for X" → fold into "pick up X"
"grasp X" → fold into "pick up X"
"lift X" → fold into "pick up X" (or "put X on Y" if it's
the transport phase of a place)
"release X" → fold into "put X on Y" (or "place X in Y")
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
detail (which hand, which grasp point) ONLY when it is needed to
disambiguate. Every subtask must begin with one of the verbs
above (no leading nouns, no "then", no "first").
- NEVER use third person. Never write "the robot", "the arm", "the
gripper moves", "it picks up" — the robot is implied. Command it,
do not describe it.
- Use the exact object nouns from the task above. If the task says
"cube", every subtask says "cube" — never switch to "block". If it
says "box", never switch to "bin"/"container". Keep vocabulary
consistent across the whole episode.
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
"turn red knob", "press start button", "go to sink".
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
must be folded into "pick up blue cube"); "the robot arm moves
towards the blue cube" (third person, too long); "carefully pick
up the cube" (adverb, article); "release the yellow block"
("block" when the task said "cube", and "release" must be folded
into a "put"/"place" subtask).
- Subtasks are non-overlapping and cover the full episode in order.
Choose the cut points yourself based on what you see in the video
(gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds. If a
candidate span would be shorter, merge it into its neighbour
rather than emitting it.
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
are preferred over many micro-steps.
- Every subtask's [start_time, end_time] must lie within
[0.0, {episode_duration}] seconds.
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
fires ONLY on the spatial situation it names; it must not change how you
label any other situation):
- STACK vs PUT: if an object is placed ON TOP OF another specific object
(not on a flat table / shelf / counter), use "stack ... on ...", not
"put". "stack blue book on green book", NOT "put blue book on table".
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
receptacle (push-fit), use "insert ... into ...", not "put".
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
on the object and the object moves WITH the hand, it is "pick up" /
"retrieve" (object leaves its location). If the gripper OPENS and the
object stays where the hand left it, it is "put" / "place" (object
arrives at a location). Decide by which way the object moves, not by
where the hand ends up.
- POUR vs PUT: only use "pour" when the source is tilted and contents
flow out; moving a full container without tilting is "put"/"place".
Output strictly valid JSON of shape:
{{
"subtasks": [
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
...
]
}}
@@ -1,67 +0,0 @@
You are generating structured augmentations of a robot task instruction
for training a language-conditioned policy. Unlike free-form rephrasing,
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
a specific element of the task while preserving its meaning.
Original task: "{base_task}"
Produce variants along five named axes. Each axis has a target count.
The whole batch should expose the policy to maximum linguistic diversity
WITHOUT changing what the robot is supposed to do.
Axes and target counts:
synonym_paraphrase ({n_synonym}):
Different wording / verbs / sentence structure. ALL information
from the original task is preserved — same object, same arm
specification if present, same orientation if present, same grasp
if present.
omit_arm ({n_omit_arm}):
Drop the left/right/both arm specification from the task. Skip
entirely (emit 0 entries) if the original task does NOT mention an
arm. Do not invent an arm specification just to omit it.
omit_orientation ({n_omit_orientation}):
Drop orientation cues (upright, sideways, facing the user,
long-edge-first, etc.). Skip entirely if no orientation cue is
present in the original task.
omit_grasp_method ({n_omit_grasp_method}):
Drop the grip / grasp method specification (pinch, wrap, hold by
the rim, etc.). Skip entirely if no grasp method is mentioned.
combined_omissions ({n_combined}):
Combine TWO of the above omissions simultaneously (e.g. drop both
arm and orientation). Skip entirely if fewer than two of (arm,
orientation, grasp_method) appear in the original task.
Hard rules:
- Each variant MUST preserve the core action, the target object, AND
the goal / destination. Do not change which object is involved, where
it goes, or the high-level action. "Navigate to the stove" may become
"go to the stove" or "head over to the stove" — it must NEVER become
"wander around the kitchen", "explore the room", or anything that
drops or generalises the stove destination. If you cannot vary the
wording without changing the goal, emit fewer variants.
- Only the FIVE listed elements (wording, arm, orientation, grasp
method, or a combination) may be varied or omitted. The verb's
meaning, the object, and the destination are fixed.
- Each variant is plain prose, no markdown, no quotes, no list numbers.
- Each variant must be DISTINCT from every other variant in the entire
output, both within and across axes. Near-duplicates are not allowed.
- If an axis cannot reach its target count because the original task
lacks the omittable element, emit fewer entries — do NOT pad the
axis with paraphrases that belong to a different axis.
- Variants should not all start with verbs — vary sentence structure
(some imperative, some polite request, some question).
Output strictly valid JSON of shape:
{{
"synonym_paraphrase": ["<v1>", "<v2>", ...],
"omit_arm": ["<v1>", "<v2>", ...],
"omit_orientation": ["<v1>", ...],
"omit_grasp_method": ["<v1>", ...],
"combined_omissions": ["<v1>", ...]
}}
@@ -1,32 +0,0 @@
You are generating training data for a Hi Robot-style policy. We need
{n} alternative phrasings of the same robot task so the policy sees
diverse user prompts during training instead of the same canonical
string repeated every frame.
Original task:
"{base_task}"
Generate exactly {n} alternative phrasings of the same task. Vary:
- formality (casual / polite / curt)
- verbosity (mostly short imperative; occasional polite request)
- word choice (synonyms, different verbs)
- sentence structure (imperative / question / suggestion)
Hard rules:
- Each phrasing MUST preserve the exact meaning of the original task.
Do not change which object is involved, the destination, or the
action. Do not add extra steps. Do not invent new objects.
- Each phrasing must be a short phrase or sentence, plain prose, no
markdown, no quotes, no list numbers.
- Phrasings must be distinct — no near-duplicates.
- Output exactly {n} entries.
Output strictly valid JSON:
{{
"rephrasings": [
"<phrasing 1>",
"<phrasing 2>",
...
]
}}
@@ -1,17 +0,0 @@
The video above shows a robot manipulation episode in full. Look at
the entire video and describe in ONE concise sentence what the robot
is doing.
Rules:
- One sentence, in natural English, like a user instruction.
- Capture the goal of the demonstration, not low-level motions.
Example: "place the yellow cube into the red bin" — not "move the
end-effector down 5cm and close the gripper".
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
- Do not invent objects or actions that aren't visible.
- Do not output anything other than the JSON object below.
Output strictly valid JSON:
{{
"task": "<single concise sentence describing what the robot does in this video>"
}}
@@ -1,32 +0,0 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
"value": "<observed value>"}}
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
"above|below|near>", "object": "<obj>"}}
Generate a question of type "{question_type}". Output strictly valid JSON:
{{
"question": "<short, frame-grounded question>",
"answer": <object whose shape matches the schema above>
}}
@@ -1,216 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:
- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
This shape lets each module operate per-episode without loading all parquet
rows into memory at once.
"""
from __future__ import annotations
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import load_tasks
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@dataclass
class EpisodeRecord:
"""Per-episode record yielded by the reader."""
episode_index: int
episode_task: str
frame_timestamps: tuple[float, ...]
frame_indices: tuple[int, ...]
data_path: Path
row_offset: int # row offset within the parquet file where this episode starts
row_count: int # number of rows for this episode
# Memoized parquet slice — populated on first ``frames_df()`` call so
# repeat queries from different modules don't re-read the whole shard.
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
def frames_df(self): # type: ignore[no-untyped-def]
"""Lazy-load the pandas slice for this episode (memoized)."""
if self._frames_df_cache is None:
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
table = pq.read_table(self.data_path)
df: pd.DataFrame = table.to_pandas()
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
drop=True
)
return self._frames_df_cache
def reconstruct_subtask_spans(
rows: Sequence[dict[str, Any]],
*,
episode_end_t: float | None = None,
) -> list[dict[str, Any]]:
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
Each span's ``end`` is the next span's ``start``. The final span's
``end`` defaults to its own ``start`` (zero-duration) — pass
``episode_end_t`` to extend it to the episode's last frame instead,
which is what downstream consumers (memory, interjection boundary
selection) expect.
Used by the ``plan`` module (plan-update pass) and the
``interjections`` module (interjection anchoring), which both need the
same span shape.
"""
sorted_rows = sorted(
(r for r in rows if r.get("style") == "subtask"),
key=lambda r: float(r["timestamp"]),
)
spans: list[dict[str, Any]] = []
for r in sorted_rows:
t = float(r["timestamp"])
if spans:
spans[-1]["end"] = t
spans.append({"text": r.get("content") or "", "start": t, "end": t})
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
spans[-1]["end"] = float(episode_end_t)
return spans
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp.
Modules use this when emitting event-style rows so the row's
timestamp matches a real parquet frame: event rows must land on an
exact frame, otherwise the per-frame event lookup the writer does
would never match them.
"""
if not frame_timestamps:
return float(t)
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
return float(nearest)
def _load_tasks_lookup(root: Path) -> dict[int, str]:
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
Returns an empty dict when the file is absent — the task description is
derived later from the video if needed. Reuses the library-level
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
frame indexed by task string with a ``task_index`` column.
"""
if not (root / DEFAULT_TASKS_PATH).exists():
return {}
tasks = load_tasks(root)
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
Episodes are yielded in ascending ``episode_index`` order. The reader does
not assume a specific chunk/file layout: it scans every ``*.parquet``
under ``data/`` and groups by ``episode_index``.
"""
tasks = _load_tasks_lookup(root)
data_dir = root / "data"
parquet_files = sorted(data_dir.rglob("*.parquet"))
only_set = set(only_episodes) if only_episodes is not None else None
for path in parquet_files:
yield from _iter_one_path(path, tasks, only_set)
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
table = pq.read_table(path)
names = table.column_names
if "episode_index" not in names:
return
episode_col = table.column("episode_index").to_pylist()
timestamp_col = (
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
)
frame_col = (
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
)
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
def _build(
ep: int,
start: int,
end: int,
task_idx: int | None,
ts_buf: list[float],
fi_buf: list[int],
) -> EpisodeRecord | None:
if only_set is not None and ep not in only_set:
return None
task = tasks.get(task_idx, "") if task_idx is not None else ""
return EpisodeRecord(
episode_index=ep,
episode_task=task,
frame_timestamps=tuple(ts_buf),
frame_indices=tuple(fi_buf),
data_path=path,
row_offset=start,
row_count=end - start,
)
cur_ep: int | None = None
start_offset = 0
ts_buf: list[float] = []
fi_buf: list[int] = []
cur_task_idx: int | None = None
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
continue
if ep != cur_ep:
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
else:
ts_buf.append(timestamp_col[i])
fi_buf.append(frame_col[i])
if cur_ep is not None:
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
@@ -1,92 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.
Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.
JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"plan",
"interjections",
"vqa",
)
@dataclass
class EpisodeStaging:
"""Filesystem layout for a single episode's staged module outputs."""
root: Path
episode_index: int
@property
def episode_dir(self) -> Path:
return self.root / f"episode_{self.episode_index:06d}"
def path_for(self, module: ModuleName) -> Path:
if module not in _MODULES:
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
return self.episode_dir / f"{module}.jsonl"
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
path = self.path_for(module)
path.parent.mkdir(parents=True, exist_ok=True)
# Atomic replace: a crash mid-write would otherwise leave a
# half-written JSONL file that ``read()`` would then fail to
# parse. Write to a sibling .tmp and rename so the target path
# only ever points at a complete file.
tmp_path = path.with_suffix(path.suffix + ".tmp")
with tmp_path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
f.write("\n")
tmp_path.replace(path)
return path
def read(self, module: ModuleName) -> list[dict[str, Any]]:
path = self.path_for(module)
if not path.exists():
return []
out: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
return {m: self.read(m) for m in _MODULES}
def has(self, module: ModuleName) -> bool:
return self.path_for(module).exists()
@@ -1,332 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after all three modules have written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
Checks (per the plan's "Intermediate staging and validation" section):
- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import (
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
column_for_style,
is_view_dependent_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class ValidationReport:
"""Outcome of one validation pass across all episodes."""
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
episodes_checked: int = 0
@property
def ok(self) -> bool:
return not self.errors
def add_error(self, message: str) -> None:
self.errors.append(message)
def add_warning(self, message: str) -> None:
self.warnings.append(message)
def summary(self) -> str:
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
def classify_vqa_answer(payload: Any) -> str | None:
"""Best-effort classification of a VQA answer payload to a question type."""
if not isinstance(payload, dict):
return None
keys = set(payload.keys())
for kind, required in VQA_ANSWER_SHAPES.items():
if required.issubset(keys):
return kind
return None
@dataclass
class StagingValidator:
"""Walks the staging tree and produces a :class:`ValidationReport`."""
timestamp_atol: float = 0.0 # exact-match by default
dataset_camera_keys: tuple[str, ...] | None = None
"""Known ``observation.images.*`` keys on the dataset. When set, the
validator additionally enforces that every view-dependent row's
``camera`` field references one of these keys. Pass ``None`` (default)
to skip that cross-check (e.g. in unit tests with no real dataset)."""
def validate(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
) -> ValidationReport:
report = ValidationReport()
for record in records:
self._validate_episode(record, staging_dir, report)
report.episodes_checked += 1
return report
def _validate_episode(
self,
record: EpisodeRecord,
staging_dir: Path,
report: ValidationReport,
) -> None:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged = staging.read_all()
all_rows: list[dict[str, Any]] = []
for module_name, rows in staged.items():
for row in rows:
row = {**row, "_module": module_name}
all_rows.append(row)
frame_ts = set(record.frame_timestamps)
events: list[dict[str, Any]] = []
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
# ``_check_column_routing`` already recorded any unknown-style error;
# don't let the same ``column_for_style`` lookup raise here uncaught.
try:
column = column_for_style(row.get("style"))
except ValueError:
continue
if column == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
for row in events:
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
self._check_speech_interjection_pairs(events, report, record.episode_index)
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
self._check_vqa_json(events, report, record.episode_index)
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
def _check_camera_field(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
dataset_camera_keys: Sequence[str] | None,
) -> None:
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
style = row.get("style")
camera = row.get("camera")
try:
validate_camera_field(style, camera)
except ValueError as exc:
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
return
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
report.add_error(
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
)
def _check_vqa_uniqueness_per_frame_camera(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
counts: dict[tuple[float, str, str], int] = {}
for row in events:
if row.get("style") != "vqa":
continue
ts = row.get("timestamp")
camera = row.get("camera")
role = row.get("role")
if ts is None or camera is None or role is None:
continue # other validators flag these
key = (float(ts), str(camera), str(role))
counts[key] = counts.get(key, 0) + 1
for (ts, camera, role), n in counts.items():
if n > 1:
report.add_error(
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
)
def _check_column_routing(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
) -> None:
style = row.get("style")
module = row.get("_module")
try:
target_col = column_for_style(style)
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
def _check_event_timestamp_alignment(
self,
row: dict[str, Any],
frame_ts: set[float],
report: ValidationReport,
episode_index: int,
) -> None:
ts = row.get("timestamp")
if ts is None:
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
return
if self.timestamp_atol == 0.0:
if float(ts) not in frame_ts:
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
)
else:
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
)
def _check_speech_interjection_pairs(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
speech_ts: dict[float, int] = {}
interjection_ts: dict[float, int] = {}
for row in events:
ts = row.get("timestamp")
if ts is None:
continue
ts_f = float(ts)
if row.get("style") is None and row.get("role") == "assistant":
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
if row.get("style") == "interjection":
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
for ts in interjection_ts:
if ts not in speech_ts:
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
def _check_plan_memory_consistency(
self,
persistent: Sequence[dict[str, Any]],
events: Sequence[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
interjection_ts = sorted(
{
float(r["timestamp"])
for r in events
if r.get("style") == "interjection" and r.get("timestamp") is not None
}
)
if persistent and not plan_ts:
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
# every interjection should have a same-timestamp plan refresh
for ts in interjection_ts:
if ts not in set(plan_ts):
report.add_error(
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
)
# memory should be emitted at subtask boundaries (subset relation)
if memory_ts and subtask_ts:
mem_set = set(memory_ts)
sub_set = set(subtask_ts)
stray = sorted(mem_set - sub_set)
if stray:
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
def _check_vqa_json(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
for row in events:
if row.get("style") != "vqa" or row.get("role") != "assistant":
continue
content = row.get("content")
if content is None:
report.add_error(
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
)
continue
try:
payload = json.loads(content)
except (TypeError, ValueError) as exc:
report.add_error(
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
)
continue
shape = classify_vqa_answer(payload)
if shape is None:
report.add_error(
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
)
@@ -1,599 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.
The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.
The client speaks one method, :meth:`VlmClient.generate_json`, which:
- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output from the server,
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
message before raising.
"""
from __future__ import annotations
import atexit
import base64
import io
import json
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
import urllib.request
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Protocol
from .config import VlmConfig
class VlmClient(Protocol):
"""Protocol every backend must implement."""
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
"""Generate one JSON-decoded response per messages list."""
@dataclass
class StubVlmClient:
"""Deterministic stub used in unit tests.
A test passes a callable that maps the *last user message text* (or, if
that is empty, the full message list) to a JSON-serializable response.
"""
responder: Callable[[Sequence[dict[str, Any]]], Any]
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
return [self.responder(list(messages)) for messages in messages_batch]
def _strip_to_json(text: str) -> Any:
text = text.strip()
# Strip <think>...</think> blocks (Qwen3 Thinking style)
while "<think>" in text and "</think>" in text:
start = text.find("<think>")
end = text.find("</think>", start) + len("</think>")
text = (text[:start] + text[end:]).strip()
# Strip ```json ... ``` fences from chat-tuned backbones
if text.startswith("```"):
first = text.find("\n")
last = text.rfind("```")
if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip()
try:
return json.loads(text)
except (ValueError, json.JSONDecodeError):
pass
# Fall back to extracting the first balanced {...} block.
obj_text = _extract_first_json_object(text)
if obj_text is None:
raise json.JSONDecodeError("No JSON object found", text, 0)
return json.loads(obj_text)
def _extract_first_json_object(text: str) -> str | None:
"""Return the first balanced ``{...}`` substring, ignoring braces in
string literals. Returns ``None`` if no balanced block is found."""
start = text.find("{")
if start < 0:
return None
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
ch = text[i]
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
# Note: ``escape`` is always False here — the ``if escape`` branch
# above already handled and reset it.
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
@dataclass
class _GenericTextClient:
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
config: VlmConfig
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
temp = temperature if temperature is not None else self.config.temperature
raw = self.generate_text(messages_batch, max_tok, temp)
out: list[Any] = []
for messages, text in zip(messages_batch, raw, strict=True):
try:
out.append(_strip_to_json(text))
continue
except (ValueError, json.JSONDecodeError):
pass
retry = list(messages) + [
{"role": "assistant", "content": text},
{
"role": "user",
"content": (
"Your previous reply was not valid JSON. "
"Reply with strictly valid JSON, no prose, no fences."
),
},
]
retry_text = self.generate_text([retry], max_tok, temp)[0]
try:
out.append(_strip_to_json(retry_text))
except (ValueError, json.JSONDecodeError):
# After retry: log preview and return None instead of crashing
# the whole pipeline. Modules treat None as "skip".
preview = retry_text.strip().replace("\n", " ")[:200]
print(
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
flush=True,
)
out.append(None)
return out
def make_vlm_client(config: VlmConfig) -> VlmClient:
"""Build the shared VLM client.
Only the ``openai`` backend is supported for now. The shipped workflow
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
optionally auto-spawning the server via ``auto_serve`` /
``serve_command``). The former in-process ``vllm`` / ``transformers``
backends were removed to keep the support surface to the HF Jobs path.
For ``stub``, construct :class:`StubVlmClient` directly with a responder
callable; it is rejected here to make accidental misuse obvious.
"""
if config.backend == "openai":
return _make_openai_client(config)
if config.backend == "stub":
raise ValueError(
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
)
if config.backend in {"vllm", "transformers"}:
raise ValueError(
f"backend={config.backend!r} (in-process local model) is not supported for now — "
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
)
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
def _make_openai_client(config: VlmConfig) -> VlmClient:
"""Backend that talks to any OpenAI-compatible server.
Compatible with ``vllm serve``, ``transformers serve``,
``ktransformers serve``, and hosted endpoints. By default the server
is expected to be already running. Set ``auto_serve=True`` to have
this client spawn one (default: ``transformers serve``), wait until
it's ready, and tear it down on process exit.
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
auto-converted to ``image_url`` data-URLs. Video blocks
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
multi-frame ``video_url`` items where supported.
"""
try:
from openai import OpenAI # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"openai package is required for backend='openai'. Install with `pip install openai`."
) from exc
api_base = config.api_base
api_key = config.api_key
auto_serve = config.auto_serve
api_bases: list[str] = [api_base]
print(
f"[lerobot-annotate] backend=openai model={config.model_id} "
f"api_base={api_base} auto_serve={auto_serve}",
flush=True,
)
if auto_serve:
if config.parallel_servers > 1:
print(
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
flush=True,
)
api_bases = _spawn_parallel_inference_servers(config)
elif _server_is_up(api_base):
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
else:
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
api_base = _spawn_inference_server(config)
api_bases = [api_base]
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
# round-robin counter for parallel mode
rr_counter = {"i": 0}
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
# rejects it with HTTP 422. Send it only when explicitly opted in via
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
extra_body: dict[str, Any] = {}
if send_mm_kwargs and mm_kwargs:
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
if config.chat_template_kwargs:
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
if extra_body:
kwargs["extra_body"] = extra_body
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
if len(batch) <= 1 or config.client_concurrency <= 1:
return [_one_call(messages, max_tok, temp) for messages in batch]
# Parallel fan-out — vllm batches these on the server side.
max_workers = min(config.client_concurrency, len(batch))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
return [f.result() for f in futures]
return _GenericTextClient(_gen, config)
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
"""Spawn ``config.parallel_servers`` independent vllm replicas.
Each replica:
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
- listens on ``serve_port + i``
- is shut down via the same atexit hook as the single-server path
Returns the list of ``api_base`` URLs the client should round-robin
across.
"""
n = config.parallel_servers
api_bases: list[str] = []
procs: list[subprocess.Popen] = []
ready_events: list[threading.Event] = []
# Multiple readiness signals — uvicorn's own banner is suppressed at
# ``--uvicorn-log-level warning``, so we also accept vllm's own
# "Starting vLLM API server" line and the route-listing line. The
# HTTP probe below is the ultimate fallback.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
# Single lock for all server-stream threads so multibyte chars from
# different servers don't interleave and tear UTF-8 sequences.
print_lock = threading.Lock()
base_cmd = config.serve_command or (
f"vllm serve {shlex.quote(config.model_id)} "
f"--tensor-parallel-size 1 "
f"--max-model-len {config.max_model_len or 32768} "
f"--uvicorn-log-level warning"
)
num_gpus = config.num_gpus if config.num_gpus > 0 else n
for i in range(n):
port = config.serve_port + i
gpu = i % num_gpus
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
env=env,
)
procs.append(proc)
ready = threading.Event()
ready_events.append(ready)
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
# Read whole lines and emit each line atomically under the
# shared print_lock so output from N servers stays readable.
assert p.stdout is not None
for line in iter(p.stdout.readline, ""):
with print_lock:
sys.stdout.write(f"[server-{idx}] {line}")
if not line.endswith(("\n", "\r")):
sys.stdout.write("\n")
sys.stdout.flush()
if any(m in line for m in ready_markers):
ev.set()
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
while not ev.is_set() and p.poll() is None:
if _server_is_up(base):
print(f"[server-{idx}] ready (http probe)", flush=True)
ev.set()
return
time.sleep(2)
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
def _shutdown() -> None:
for i, p in enumerate(procs):
if p.poll() is None:
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
p.send_signal(signal.SIGINT)
for p in procs:
try:
p.wait(timeout=15)
except subprocess.TimeoutExpired:
p.kill()
p.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
for i, p in enumerate(procs):
if p.poll() is not None:
raise RuntimeError(
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
)
time.sleep(2)
if any(not ev.is_set() for ev in ready_events):
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
return api_bases
def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
url = api_base.rstrip("/") + "/models"
# ``api_base`` is the user-configured local-server URL we just spawned
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
# is for arbitrary user-controlled URLs with file:/ schemes which
# cannot reach this code path.
try:
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
return resp.status == 200
except Exception: # noqa: BLE001
return False
def _spawn_inference_server(config: VlmConfig) -> str:
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
accepts ``/v1/models``, and register a shutdown hook.
Streams the server's stdout/stderr to the parent terminal in
real-time on a background thread so users can see model-load
progress and errors as they happen.
Returns the full ``api_base`` URL the OpenAI client should use.
"""
cmd = config.serve_command
if not cmd:
cmd = (
f"transformers serve {shlex.quote(config.model_id)} "
f"--port {config.serve_port} --continuous-batching"
)
api_base = f"http://localhost:{config.serve_port}/v1"
print(f"[server] launching: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
# Watch the server output for the uvicorn readiness banner. This is
# more reliable than polling /v1/models because transformers serve
# rescans its cache on every model-list request, which can exceed
# the urllib timeout and trigger an infinite probe loop.
ready_event = threading.Event()
# See _spawn_parallel_inference_servers for why we accept these.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
def _probe() -> None:
while not ready_event.is_set() and proc.poll() is None:
if _server_is_up(api_base):
print("[server] ready (http probe)", flush=True)
ready_event.set()
return
time.sleep(2)
threading.Thread(target=_probe, daemon=True).start()
def _stream_output() -> None:
# Read raw chunks instead of iterating lines so tqdm progress
# bars (which overwrite using \r) flush in real time.
assert proc.stdout is not None
buf = ""
prefix_started = False
while True:
ch = proc.stdout.read(1)
if ch == "":
# process exited; flush any tail
if buf:
sys.stdout.write(buf)
sys.stdout.flush()
return
if not prefix_started:
sys.stdout.write("[server] ")
prefix_started = True
sys.stdout.write(ch)
sys.stdout.flush()
buf += ch
if ch in ("\n", "\r"):
if any(marker in buf for marker in ready_markers):
ready_event.set()
buf = ""
prefix_started = False
threading.Thread(target=_stream_output, daemon=True).start()
def _shutdown() -> None:
if proc.poll() is None:
print(f"[server] stopping pid={proc.pid}", flush=True)
proc.send_signal(signal.SIGINT)
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while time.monotonic() < deadline:
if proc.poll() is not None:
raise RuntimeError(
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
f"See [server] log lines above for the cause."
)
if ready_event.wait(timeout=2):
return api_base
proc.terminate()
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
def _to_openai_messages(
messages: Sequence[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
"""Convert internal messages to OpenAI chat format.
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
(``fps`` from ``video_url`` blocks) are extracted out so the caller
can pass them via ``extra_body.mm_processor_kwargs`` rather than
inside the content blocks (which transformers serve rejects).
File-URL video blocks are inlined as base64 data URLs.
"""
out_messages: list[dict[str, Any]] = []
mm_kwargs: dict[str, Any] = {}
for message in messages:
content = message.get("content")
if not isinstance(content, list):
out_messages.append({"role": message["role"], "content": content})
continue
out_blocks: list[dict[str, Any]] = []
for block in content:
block_type = block.get("type") if isinstance(block, dict) else None
if block_type == "text":
out_blocks.append({"type": "text", "text": block.get("text", "")})
elif block_type == "image":
out_blocks.append(
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
)
elif block_type == "video":
frames = block.get("video", [])
for img in frames:
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
elif block_type == "video_url":
video_url = dict(block["video_url"])
url = video_url.get("url", "")
if url.startswith("file://"):
video_url["url"] = _file_to_data_url(url[len("file://") :])
out_blocks.append({"type": "video_url", "video_url": video_url})
fps = block.get("fps")
if fps is not None:
mm_kwargs["fps"] = fps
else:
out_blocks.append(block)
out_messages.append({"role": message["role"], "content": out_blocks})
return out_messages, mm_kwargs
def _file_to_data_url(path: str) -> str:
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
return f"data:video/mp4;base64,{b64}"
def _pil_to_data_url(image: Any) -> str:
"""Encode a PIL.Image as a base64 data URL."""
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
return f"data:image/png;base64,{b64}"
@@ -1,341 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.
For every episode the writer:
1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. writes the parquet shard back in place.
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant —
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
it directly rather than reading a redundant per-row column.
Invariants enforced here (and re-checked by the validator):
- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
(timestamps come straight from the source parquet — never recomputed);
- every row passes ``column_for_style(style)``.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
# Tool schema constants live in lerobot.datasets.language — single
# source of truth. Re-exported here so existing imports
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
# keep working.
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
# events are bucketed per-frame, but within a frame we still want determinism
return (
row.get("style") or "",
row.get("role") or "",
row.get("camera") or "",
)
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
"""Coerce a staged row into the language-column struct shape.
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
writer infers the parquet struct schema from insertion order, so
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
"""
camera = row.get("camera")
validate_camera_field(style, camera)
out: dict[str, Any] = {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
}
if with_timestamp:
out["timestamp"] = float(row["timestamp"])
out["camera"] = None if camera is None else str(camera)
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
return out
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
if style not in PERSISTENT_STYLES:
raise ValueError(
f"persistent slice contains row with non-persistent style {style!r}; "
"row would be misrouted under column_for_style()"
)
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
if "role" not in row:
# Friendly error from the writer instead of a raw KeyError below;
# the validator doesn't check ``role`` yet.
raise ValueError(f"persistent row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=True)
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
style = row.get("style")
if style is not None and style not in EVENT_ONLY_STYLES:
raise ValueError(
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
)
if column_for_style(style) != LANGUAGE_EVENTS:
raise ValueError(f"event row with style {style!r} would not route to language_events")
if "role" not in row:
raise ValueError(f"event row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=False)
def _normalize_tool_calls(value: Any) -> list[Any] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
return list(value)
def _validate_atom_invariants(row: dict[str, Any]) -> None:
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
has_content = row.get("content") is not None
has_tools = row.get("tool_calls") is not None
if not (has_content or has_tools):
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
if row.get("style") is None and not has_tools:
raise ValueError(f"style=None requires tool_calls: {row!r}")
def _validate_speech_atom(row: dict[str, Any]) -> None:
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
if row.get("style") is not None:
return # not a speech atom
if row.get("role") != "assistant":
raise ValueError(f"speech atom must have role=assistant: {row!r}")
if row.get("content") is not None:
raise ValueError(f"speech atom must have content=null: {row!r}")
tool_calls = row.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
first = tool_calls[0]
if not isinstance(first, dict):
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
if first.get("type") != "function":
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
fn = first.get("function") or {}
if fn.get("name") != "say":
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
args = fn.get("arguments") or {}
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
@dataclass
class LanguageColumnsWriter:
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
drop_existing_subtask_index: bool = True
def write_all(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> list[Path]:
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
for record in records:
episodes_by_path[record.data_path].append(record)
written: list[Path] = []
for path, eps in episodes_by_path.items():
self._rewrite_one(path, eps, staging_dir, root)
written.append(path)
return written
def _rewrite_one(
self,
path: Path,
episodes: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> None:
table = pq.read_table(path)
n_rows = table.num_rows
# Ensure we cover every episode in the file. Episodes that don't have
# staging artifacts are passed through with empty annotation lists —
# this keeps the writer idempotent and safe for partial reruns.
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
for record in episodes:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged_per_ep[record.episode_index] = staging.read_all()
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
for ep_index, ep_staged in staged_per_ep.items():
persistent_rows: list[dict[str, Any]] = []
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
for _module_name, rows in ep_staged.items():
for row in rows:
style = row.get("style")
if column_for_style(style) == LANGUAGE_PERSISTENT:
persistent_rows.append(row)
else:
event_rows.append(row)
persistent_rows.sort(key=_row_persistent_sort_key)
normalized_persistent = []
for r in persistent_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
normalized_persistent.append(_normalize_persistent_row(r))
persistent_by_ep[ep_index] = normalized_persistent
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
for r in event_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
ts = float(r["timestamp"])
buckets[ts].append(_normalize_event_row(r))
for ts in list(buckets.keys()):
buckets[ts].sort(key=_row_event_sort_key)
events_by_ep_ts[ep_index] = buckets
episode_col = (
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
)
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
if episode_col is None or ts_col is None:
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
per_row_persistent: list[list[dict[str, Any]]] = []
per_row_events: list[list[dict[str, Any]]] = []
for i in range(n_rows):
ep = episode_col[i]
ts = float(ts_col[i])
per_row_persistent.append(persistent_by_ep.get(ep, []))
buckets = events_by_ep_ts.get(ep, {})
per_row_events.append(buckets.get(ts, []))
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
self,
table: pa.Table,
persistent: list[list[dict[str, Any]]],
events: list[list[dict[str, Any]]],
*,
drop_old: bool,
) -> pa.Table:
cols = []
names = []
for name in table.column_names:
if drop_old and name == "subtask_index":
continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
continue # we'll re-add canonical versions
# Strip any legacy ``tools`` column previously emitted by older
# writers — the schema no longer uses it (constant lives in
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
if name == "tools":
continue
cols.append(table.column(name))
names.append(name)
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
# exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
return pa.Table.from_arrays(cols, names=names)
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
"""Build a canonical speech tool-call atom for the events column."""
return {
"role": "assistant",
"content": None,
"style": None,
"timestamp": float(timestamp),
"camera": None,
"tool_calls": [
{
"type": "function",
"function": {
"name": "say",
"arguments": {"text": text},
},
}
],
}
-70
View File
@@ -18,7 +18,6 @@ from __future__ import annotations
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
@@ -244,72 +243,3 @@ def sanity_check_dataset_robot_compatibility(
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)
########################################################################################
# Teleoperator smooth handover helpers
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
# being a root module.
########################################################################################
def teleop_supports_feedback(teleop) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback (i.e. have non-empty
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
``target_pos`` is expected to be in the teleop's action/feedback key space.
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
the robot action key space directly.
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
follower robot does not receive new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def follower_smooth_move_to(
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm to
the follower, the follower is brought to the teleop's current pose so the
robot meets the operator's hand rather than jumping to it on the first frame.
Both ``current`` and ``target`` must be in the robot action key space
(i.e. the output of ``robot_action_processor``).
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
-146
View File
@@ -205,149 +205,3 @@ class WandBLogger:
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def log_training_examples(
self,
batch: dict,
step: int,
*,
camera_keys: list[str],
n_samples: int = 4,
policy=None,
predict_actions: bool = False,
mode: str = "train",
) -> None:
"""Push a ``wandb.Table`` of training-example rows for the current batch.
Each row is one batch element with:
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
HWC, uint8 or float in [0,1] — auto-detected),
* any text fields present in the batch (``task`` / ``subtask`` /
``memory`` / ``instruction``),
* ground-truth action first/last frame (the action chunk's
endpoints — gives a quick sense of trajectory direction),
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
``predict_action_chunk`` first/last frame alongside.
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
training loop calls it once every N steps. Cheap to keep on: with
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
enabled) run one extra inference forward pass.
"""
import logging # noqa: PLC0415
import numpy as np # noqa: PLC0415
import torch # noqa: PLC0415
if mode not in {"train", "eval"}:
raise ValueError(mode)
# Batch size — first tensor-like value wins.
bsz = next(
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
None,
)
if not bsz:
return
n = min(int(n_samples), bsz)
# Optional predicted-action forward pass on the first n samples.
pred_actions: np.ndarray | None = None
if predict_actions and policy is not None:
was_training = policy.training
try:
policy.eval()
sub_batch = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sub_batch[k] = v[:n]
elif isinstance(v, (list, tuple)):
sub_batch[k] = list(v[:n])
else:
sub_batch[k] = v
with torch.no_grad():
pred = policy.predict_action_chunk(sub_batch)
pred_actions = pred.detach().cpu().float().numpy()
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: predict_action_chunk failed (%s) — "
"skipping predicted-action columns",
exc,
)
pred_actions = None
finally:
if was_training:
policy.train()
present_cameras = [c for c in camera_keys if c in batch]
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
columns = ["sample"]
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
columns.extend(text_keys)
columns.append("gt_action_first")
columns.append("gt_action_last")
if pred_actions is not None:
columns.append("pred_action_first")
columns.append("pred_action_last")
table = self._wandb.Table(columns=columns)
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
if t.ndim == 4:
t = t[0]
# CHW -> HWC.
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
t = t.permute(1, 2, 0)
arr = t.detach().cpu().float().numpy()
if arr.size and float(arr.max()) <= 1.5:
arr = arr * 255.0
return np.clip(arr, 0, 255).astype(np.uint8)
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
arr = a.detach().cpu().float().numpy()
if arr.ndim == 2: # (T, D)
return (
str(np.round(arr[0], 3).tolist()),
str(np.round(arr[-1], 3).tolist()),
)
if arr.ndim == 1:
rounded = np.round(arr, 3).tolist()
return (str(rounded), str(rounded))
return (str(arr.tolist()), str(arr.tolist()))
for i in range(n):
row: list = [i]
for cam in present_cameras:
try:
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: camera %s sample %d failed (%s)",
cam,
i,
exc,
)
row.append(None)
for tk in text_keys:
v = batch[tk]
if isinstance(v, (list, tuple)):
row.append(str(v[i]) if i < len(v) else "")
else:
row.append(str(v))
action = batch.get("action")
if isinstance(action, torch.Tensor) and action.ndim >= 1:
first, last = _action_endpoints(action[i])
row.append(first)
row.append(last)
else:
row.append("")
row.append("")
if pred_actions is not None:
p = torch.from_numpy(pred_actions[i])
pfirst, plast = _action_endpoints(p)
row.append(pfirst)
row.append(plast)
table.add_data(*row)
self._wandb.log({f"{mode}/examples": table}, step=step)
+2 -2
View File
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
private: bool | None = None
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
-66
View File
@@ -62,72 +62,6 @@ class WandBConfig:
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
add_tags: bool = True # If True, save configuration as tags in the WandB run.
# Periodic training-example dump (independent of ``log_freq``). When > 0,
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
# with one row per sampled batch element containing each camera view
# (rendered as ``wandb.Image``), any text fields present in the batch
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
# are unaffected.
log_examples_freq: int = 5000
# Number of batch elements to include in each example dump.
log_examples_n: int = 4
# If True (default), also run ``policy.predict_action_chunk`` on the logged
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
# columns to the table. Costs one extra forward pass per dump — negligible
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
# implement ``predict_action_chunk`` or you want to skip the extra forward.
log_examples_predict_actions: bool = True
@dataclass
class EMAConfig:
"""Exponential Moving Average of trainable policy parameters.
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
pi052) benefit substantially from averaging late-training
parameter oscillations — see Chi et al. 2023 §V.D. The official
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
``0.999`` for its pi05_libero config; the openpi PyTorch port
explicitly lists EMA as unsupported, and LeRobot main inherited
that gap. Enabling this flag plugs ema-pytorch
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
training loop with a shadow ``nn.Module`` clone of the policy.
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
params) + one elementwise update per training step (~1% step time).
Off by default (opt-in): EMA is only beneficial for flow-matching /
diffusion policies (pi0/pi05/pi052), and the fp32 shadow copy is pure
overhead for other policies (e.g. VLA-JEPA). Set ``--ema.enable=true``
to turn it on (the pi05/pi052 training recipes do this). openpi (JAX)
ships EMA on for every config; enable it explicitly to match that.
"""
enable: bool = False
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
# ema-pytorch as ``beta``).
# 0.999 — last ~1000 steps; pi05_libero default in openpi
# 0.99 — last ~100 steps; openpi top-level default
# 0.75 — very fast EMA (Diffusion Policy original setting)
# 0.9999 — very slow EMA (long classification runs)
decay: float = 0.99
# Skip the first N calls to ``ema.update()``; during this window
# the shadow is just a hard copy of the live weights (no averaging).
# Lets early-training rapid changes settle before averaging begins.
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
# ramp like older lerobot EMA implementations).
warmup_steps: int = 0
# When True, the periodic eval block uses the EMA shadow model
# directly (``ema.ema_model``) instead of the live policy. Standard
# practice for diffusion-style policies — eval scores are usually
# 13% higher than the live policy at the same step.
use_for_eval: bool = True
# When True, the periodic wandb training-example dump uses the EMA
# shadow for the optional predicted-action columns (so what you see
# in W&B matches eval behavior).
use_for_wandb_examples: bool = True
@dataclass
+3 -18
View File
@@ -147,16 +147,7 @@ class TrainingRecipe:
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and the recipe supervises something.
A recipe is valid if it has at least one of:
* a ``target: true`` assistant turn (drives text-CE supervision), or
* a ``stream: low_level`` turn (drives flow / action supervision via
``predict_actions=True``, even when no assistant turn is targeted —
e.g. π0.5-style ``low_level_execution`` where the action expert
conditions on a user-only ``${subtask}`` prompt).
"""
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
@@ -165,14 +156,8 @@ class TrainingRecipe:
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
has_target = any(turn.target for turn in self.messages)
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
if not (has_target or has_low_level):
raise ValueError(
"Message recipes must contain at least one supervised turn — "
"either ``target: true`` (text CE) or ``stream: low_level`` "
"(flow/action loss)."
)
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
@@ -1,68 +0,0 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.30
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.55
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into a
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on this dataset.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
@@ -1,99 +0,0 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -1,114 +0,0 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into the
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on Hi-Robot-style datasets.
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.10
bindings:
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
# Spoken reply only: the assistant turn carries no text content,
# just a `say` tool call (`tool_calls_from: speech`). The chat
# tokenizer flattens it to a `<say>...</say>` marker, so the
# supervised target trains the model to respond to an
# interjection with a spoken acknowledgement.
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
# VQA is view-dependent — each camera gets its own sub-recipe so the
# resolver disambiguates via `camera=...`. Camera keys match
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
ask_vqa_top:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -1,61 +0,0 @@
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
#
# Trains two things only: subtasks and VQA. Plan and memory are
# intentionally left out — keeps the prompt short and the training
# surface small. The fuller blend with memory + spoken replies is
# ``subtask_mem_vqa_speech.yaml``.
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# PI052's text tokenizer renders these messages as plain
# ``Role: content`` text (PaliGemma is not chat-pretrained).
blend:
high_level_subtask:
weight: 0.40
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# the high-level loop (``HighLevelSubtaskFwd``) generates the
# subtask via the LM head and feeds it here. The action expert's
# prefix is [images, subtask, state]. ``stream: low_level`` flips
# ``predict_actions=True`` so the flow loss fires; no text-CE
# target here (subtask prediction is owned by
# ``high_level_subtask``).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+2 -19
View File
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
@@ -111,20 +111,9 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
ema: EMAConfig = field(default_factory=EMAConfig)
peft: PeftConfig | None = None
# VQA oversampling. When set (a fraction in (0, 1)), the training
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
# carrying a `vqa` language annotation often enough that they make
# up roughly this fraction of the training stream. VQA annotations
# are typically sparse, so without this they are underrepresented.
# `None` (default) keeps uniform episode-aware sampling.
vqa_target_fraction: float | None = None
# Sample weighting configuration (e.g., for RA-BC training). Old
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
# field by ``_migrate_legacy_rabc_keys`` above.
# Sample weighting configuration (e.g., for RA-BC training)
sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys
@@ -188,12 +177,6 @@ class TrainPipelineConfig(HubMixin):
)
active_cfg = self.trainable_config
if self.rename_map and active_cfg.pretrained_path is None:
raise ValueError(
"`rename_map` requires a pretrained policy checkpoint. "
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{active_cfg.type}"
+2 -15
View File
@@ -35,6 +35,7 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -49,24 +50,11 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
def make_dataset(*args, **kwargs):
from .factory import make_dataset as _make_dataset
return _make_dataset(*args, **kwargs)
def resolve_delta_timestamps(*args, **kwargs):
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
return _resolve_delta_timestamps(*args, **kwargs)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here.
# Import directly: ``from lerobot.datasets.io_utils import ...``
@@ -77,7 +65,6 @@ __all__ = [
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"WeightedEpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
-43
View File
@@ -126,53 +126,10 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
# Datasets annotated with the PR1 language columns may have been
# written without registering those columns in ``meta/info.json``
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
# back-filled by ``lerobot-annotate``). Probe a single parquet
# shard and graft the column features on so the strict
# ``Dataset.from_parquet`` cast doesn't fail with
# ``column names don't match``.
features = self._extend_features_with_language_columns(features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _extend_features_with_language_columns(
self, features: datasets.Features
) -> datasets.Features:
"""Add ``language_persistent`` / ``language_events`` to ``features``
when the underlying parquet shards declare them but the metadata
doesn't. No-op when neither column is present or both are
already registered.
"""
# Find any one parquet to peek at; bail if there are none yet
# (the dataset will fail later for an unrelated reason and we
# want that error to surface as-is).
try:
sample = next((self.root / "data").glob("*/*.parquet"))
except StopIteration:
return features
from pyarrow import parquet as _pq # noqa: PLC0415
schema_names = set(_pq.read_schema(sample).names)
from .language import ( # noqa: PLC0415
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
language_events_column_feature,
language_persistent_column_feature,
)
extra: dict[str, object] = {}
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
extra[LANGUAGE_EVENTS] = language_events_column_feature()
if not extra:
return features
return datasets.Features({**features, **extra})
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
+3 -89
View File
@@ -170,29 +170,6 @@ def render_sample(
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
# VQA-priority routing. A ``vqa`` annotation is sparse and
# view-dependent; the plain weighted blend would (a) waste a draw
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
# sub-recipes and *this* frame carries one of their VQA bindings,
# render VQA here regardless of the weighted draw. That makes VQA's
# recipe-side training share equal the VQA-annotation density (the
# maximum reachable without a dataset-level oversampling sampler).
if recipe.blend is not None:
vqa_rendered = _render_vqa_if_present(
recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
if vqa_rendered is not None:
return vqa_rendered
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
@@ -206,59 +183,6 @@ def render_sample(
return _render_message_recipe(selected_recipe, bindings)
def _render_vqa_if_present(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> RenderedMessages | None:
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
annotation; otherwise return ``None`` so the caller falls back to the
normal weighted blend.
When several VQA sub-recipes resolve (e.g. a frame annotated for more
than one camera), one is chosen deterministically by relative weight.
"""
assert recipe.blend is not None
renderable: list[tuple[float, RenderedMessages]] = []
for name, component in recipe.blend.items():
if not name.startswith("ask_vqa"):
continue
bindings = _resolve_bindings(
component,
persistent=persistent,
events=events,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
rendered = _render_message_recipe(component, bindings)
if rendered is not None:
renderable.append((float(component.weight or 0.0), rendered))
if not renderable:
return None
if len(renderable) == 1:
return renderable[0][1]
# Multiple cameras have a VQA for this frame — deterministic pick by
# relative weight (fall back to a uniform draw if all weights are 0).
total = sum(w for w, _ in renderable) or float(len(renderable))
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total
cumulative = 0.0
for w, rendered in renderable:
cumulative += w or (total / len(renderable))
if draw < cumulative:
return rendered
return renderable[-1][1]
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
@@ -422,15 +346,7 @@ def _render_message_recipe(
if turn.target:
target_indices.append(message_idx)
# A render is meaningful if it supervises *something*: either a
# text-CE target turn, or a ``low_level`` stream turn (flow / action
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
# ``user(${subtask})`` with ``stream: low_level`` and no target).
# Without this, a flow-only recipe renders to ``None`` every time
# the blend draws it → ``predict_actions`` is never True → the
# action expert never receives a flow loss.
has_low_level = any(stream == "low_level" for stream in streams)
if not target_indices and not has_low_level:
if not target_indices:
return None
rendered = {
@@ -487,10 +403,8 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
# Valid iff it supervises something: a text-CE target turn OR a
# ``low_level`` stream turn (flow / action supervision).
if not target_indices and not any(s == "low_level" for s in streams):
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
+2 -3
View File
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool | None = None,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
**card_kwargs,
@@ -543,8 +543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
tag_version: If ``True``, create a Git tag for the current codebase
version.
push_videos: If ``False``, skip uploading the ``videos/`` directory.
private: If ``True``, create a private repository. If ``None``
(default), defer to the org default on the Hub (only affects orgs).
private: If ``True``, create a private repository.
allow_patterns: Glob pattern(s) restricting which files to upload.
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
of ``upload_folder`` for very large datasets.
-63
View File
@@ -84,66 +84,3 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
proportion to per-frame weights.
Used to oversample frames carrying a sparse annotation (e.g. a VQA
question) so the policy sees them more often than their natural
dataset density. One epoch still yields ``len(self.indices)``
samples — the weights only change the *composition* of the stream,
not its length. Each epoch re-draws, so the oversampled subset
varies run to run.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
frame_weights,
*,
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""
Args:
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
dataset_to_indices: Episode end indices.
frame_weights: 1-D sequence/tensor of non-negative weights, one per
dataset frame (length == total dataset frames). Higher weight ⇒
that frame is sampled more often.
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
frame filtering is applied first, then weighting is restricted
to the surviving frames.
"""
super().__init__(
dataset_from_indices,
dataset_to_indices,
episode_indices_to_use=episode_indices_to_use,
drop_n_first_frames=drop_n_first_frames,
drop_n_last_frames=drop_n_last_frames,
shuffle=False,
)
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
idx = torch.tensor(self.indices, dtype=torch.long)
if weights.numel() <= int(idx.max()):
raise ValueError(
f"frame_weights has {weights.numel()} entries but the sampler "
f"references frame index {int(idx.max())}."
)
selected = weights[idx]
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
raise ValueError("frame_weights must be finite and non-negative.")
if float(selected.sum()) <= 0.0:
# All surviving frames have zero weight — fall back to uniform.
selected = torch.ones_like(selected)
self._weights = selected
def __iter__(self) -> Iterator[int]:
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
for i in picks.tolist():
yield self.indices[i]
+10 -17
View File
@@ -366,24 +366,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
msg = (
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
f"either doesn't exist on the Hub yet, or it was uploaded "
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
f" from huggingface_hub import HfApi\n"
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
# ``__init__`` indexes ``response.headers`` unconditionally on
# current ``huggingface_hub`` versions. Constructing it without
# a real ``Response`` object crashes with either
# ``TypeError: missing 1 required keyword-only argument`` (old
# builds) or ``AttributeError: 'NoneType' object has no attribute
# 'headers'`` (new builds). Skip that path entirely — this isn't
# really an HTTP error, it's a configuration issue — and raise a
# plain ``RuntimeError`` so the message actually reaches the
# caller.
raise RuntimeError(msg)
if target_version in hub_versions:
return f"v{target_version}"
+10 -13
View File
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
# These correspond to the PandaOmron robot in RoboCasa365.
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
@@ -101,15 +101,14 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
"""Split a flat (12,) action vector into a RoboCasa action dict.
Layout (openpi / robocasa.utils.env_utils.convert_action order):
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
"""
return {
"action.end_effector_position": flat_action[0:3],
"action.end_effector_rotation": flat_action[3:6],
"action.gripper_close": flat_action[6:7],
"action.base_motion": flat_action[7:11],
"action.control_mode": flat_action[11:12],
"action.base_motion": flat_action[0:4],
"action.control_mode": flat_action[4:5],
"action.end_effector_position": flat_action[5:8],
"action.end_effector_rotation": flat_action[8:11],
"action.gripper_close": flat_action[11:12],
}
@@ -231,14 +230,12 @@ class RoboCasaEnv(gym.Env):
return {"pixels": images}
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
# openpi state order: ee first, then base, then gripper (matches the
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
agent_pos = np.concatenate(
[
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
raw_obs.get("state.base_position", np.zeros(3)),
raw_obs.get("state.base_rotation", np.zeros(4)),
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
raw_obs.get("state.gripper_qpos", np.zeros(2)),
],
axis=-1,
-2
View File
@@ -104,8 +104,6 @@ class AdamWConfig(OptimizerConfig):
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
foreach: bool | None = None
fused: bool | None = None
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
-2
View File
@@ -25,7 +25,6 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pi052.configuration_pi052 import PI052Config as PI052Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -50,7 +49,6 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"PI052Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+2 -144
View File
@@ -57,85 +57,11 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
def _restore_pi052_pretrained_state(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
pretrained_path: str,
) -> None:
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
pi052's preprocessor includes steps whose constructor args don't
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
fitted-tokenizer path that may not exist at eval time). We rebuild
those pipelines fresh from ``config.recipe_path`` and then walk
over the saved ``policy_{pre,post}processor.json`` files to find
each step's ``state_file`` reference and load the bytes back into
the corresponding fresh step. Today that's only the
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
state quantile stats), but the loop is generic so any future
stateful step picks up its blob automatically.
Pairing is by ``registry_name`` AND position so a benign reorder
on the saved side surfaces a warning rather than silently feeding
the wrong tensors into the wrong step.
"""
import json # noqa: PLC0415
import logging # noqa: PLC0415
from pathlib import Path # noqa: PLC0415
from safetensors.torch import load_file # noqa: PLC0415
base = Path(pretrained_path)
if not base.exists():
return
log = logging.getLogger(__name__)
for pipeline, config_filename in [
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
]:
config_path = base / config_filename
if not config_path.exists():
continue
saved = json.loads(config_path.read_text())
for idx, (saved_step, fresh_step) in enumerate(
zip(saved.get("steps", []), pipeline.steps, strict=False)
):
state_file = saved_step.get("state_file")
if not state_file:
continue
saved_name = saved_step.get("registry_name")
fresh_name = getattr(type(fresh_step), "_registry_name", None)
if saved_name and fresh_name and saved_name != fresh_name:
log.warning(
"PI052 state restore: %s step %d registry name mismatch "
"(saved=%s, fresh=%s); skipping %s",
config_filename, idx, saved_name, fresh_name, state_file,
)
continue
state_path = base / state_file
if not state_path.exists():
log.warning(
"PI052 state restore: %s missing at %s; %s left at fresh init",
state_file, base, fresh_name,
)
continue
fresh_step.load_state_dict(load_file(str(state_path)))
log.info(
"PI052 state restore: loaded %s into %s (step %d)",
state_file, fresh_name, idx,
)
def _reconnect_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
) -> None:
@@ -203,10 +129,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi052":
from .pi052.modeling_pi052 import PI052Policy
return PI052Policy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
@@ -235,10 +157,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -255,8 +173,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
"pi052", "gaussian_actor", "smolvla", "wall_x", "molmoact2".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x", "molmoact2".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -279,10 +197,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi052":
from .pi052.configuration_pi052 import PI052Config
return PI052Config(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
@@ -297,8 +211,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -327,12 +239,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
# Optional: HF Hub repo id of the dataset the policy is being
# trained on. Used by policies that auto-fit pieces of their
# preprocessing (e.g. pi052's FAST action tokenizer per
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
# policies fall back to their universal pre-fitted tokenizers.
dataset_repo_id: str | None
dataset_meta: Any | None
@@ -366,29 +272,6 @@ def make_pre_post_processors(
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
# pi052 pipelines don't roundtrip through the saved
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
# FAST tokenizer path. Generic ``from_pretrained`` then dies
# with ``RenderMessagesStep.__init__() missing 1 required
# positional argument: 'recipe'`` (job 22164494).
#
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
# fresh from ``config.recipe_path`` and transplant the saved
# stateful blobs (normalizer stats) from the checkpoint dir.
from .pi052.processor_pi052 import make_pi052_pre_post_processors
preprocessor, postprocessor = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
@@ -483,22 +366,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif policy_cfg.type == "pi052":
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
# come before the PI05Config isinstance check below (otherwise
# pi052 would silently pick up π0.5's processor).
from .pi052.processor_pi052 import make_pi052_pre_post_processors
processors = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
# enabled — the train loop sets it from ``--dataset.repo_id``.
# When ``None``, ``make_pi052_pre_post_processors`` skips
# the auto-fit and uses the universal tokenizer.
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
elif isinstance(policy_cfg, PI05Config):
from .pi05.processor_pi05 import make_pi05_pre_post_processors
@@ -548,7 +415,6 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -566,14 +432,6 @@ def make_pre_post_processors(
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+1
View File
@@ -178,6 +178,7 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
-42
View File
@@ -1,42 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
inference recipe on lerobot.
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
(the "high-level subtask prediction" of the paper, §IV.D),
* AR text generation at inference (:meth:`PI052Policy.select_message`),
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
text head against missing context at inference.
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
canonical training recipe and
``examples/training/pi052_hirobot.slurm`` for the launcher.
"""
from .configuration_pi052 import PI052Config
from .modeling_pi052 import PI052Policy
from .processor_pi052 import make_pi052_pre_post_processors
from .text_processor_pi052 import PI052TextTokenizerStep
__all__ = [
"PI052Config",
"PI052Policy",
"PI052TextTokenizerStep",
"make_pi052_pre_post_processors",
]
@@ -1,235 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
hierarchical inference recipe.
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
~300M Gemma action expert, joint training with FAST tokens during
pre-train and flow matching during post-train), but with the
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
to predict both:
* **subtask strings** at the high level (cross-entropy on the LM
head), and
* **action chunks** at the low level (flow matching on the
action-expert tokens).
This is the dual-head co-training pattern from the paper:
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, )‖²
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
inference into a text-prediction step followed by an action-prediction
step, which the multi-rate ``PI052Runtime`` (in
``lerobot.policies.pi052.inference``) drives at separate rates.
"""
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig
from ..pi05.configuration_pi05 import PI05Config
@PreTrainedConfig.register_subclass("pi052")
@dataclass
class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
Recipe-driven dual-head training: the flow head supervises actions,
the LM head supervises subtask / plan / memory / VQA text. The
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
"""
# Recipe / language stack ---------------------------------------------
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
shipped alongside this policy. Set to ``None`` to disable recipe
rendering and fall back to π0.5's single-task ``Task: ... Action:``
prompt path (unannotated datasets keep working that way)."""
apply_chat_template: bool = False
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
chat template, so we don't apply one. The recipe renderer's output
is concatenated as a plain prefix + assistant suffix instead,
mirroring how the π0.5 paper's high-level inference samples text
auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
# Paper §IV.D uses α=10 between the flow and text terms, assuming
# text is a rare auxiliary task. With the recipe stack the flow-only
# `low_level` branch fires on a large share of samples, so α=10
# swamps the LM head and collapses generation into degenerate
# repetition. We use the milder 5:1 split here.
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
flow_loss_weight: float = 5.0
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
flow:text split than the paper's α=10, since the flow-only
``low_level`` recipe already gives the action expert frequent
gradient. Lower it further if the LM head still underfits."""
# Backbone training ---------------------------------------------------
unfreeze_lm_head: bool = True
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
The existing ``PI05Policy`` zeroes / freezes the head on load
because it never reads from it. Must be ``True`` for π0.5-style
hierarchical inference."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# Randomly drop non-target context messages so the LM head learns
# to handle missing /
# stale plan / memory at inference. Defaults to 0.0 so behaviour
# is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
# FAST discrete-action supervision — paper §III.B-C ------------------
# When enabled, actions are *also* tokenised via the FAST tokenizer
# ("physical-intelligence/fast") and supervised with cross-entropy
# on the PaliGemma LM head — exactly as in the paper's pre-training
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
# ActionTokenizerProcessorStep is wired into the preprocessor
# pipeline when this flag is set; the loss is computed in
# PI052Policy.forward.
enable_fast_action_loss: bool = True
"""If True, tokenise actions with the FAST tokenizer and add a
cross-entropy loss on the LM head. On by default to match the
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
§III.B-C Eq. 1). Set to False if you only want the
post-training-style flow + text recipe."""
action_tokenizer_name: str = "physical-intelligence/fast"
"""HF identifier for the FAST action tokenizer."""
max_action_tokens: int = 256
"""Maximum number of FAST tokens per action chunk."""
fast_skip_tokens: int = 128
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
collisions with PaliGemma's text vocabulary."""
fast_action_loss_weight: float = 1.0
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
auto_fit_fast_tokenizer: bool = False
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
base_tokenizer_name, fit_samples)``. On cache miss, it loads
``action_tokenizer_name`` as a base, samples
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
``.fit()``, saves the result, and uses *that* fitted path as the
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
explicitly recommend per-dataset fitting for best compression.
Off by default because the fit requires a separate pre-training
pass over the dataset (~1-2 min on a medium dataset) and depends
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
when you want paper-faithful compression; leave off to fall back
on the universal ``physical-intelligence/fast`` codebook."""
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
fast_tokenizer_fit_samples: int = 1024
"""Number of action chunks to sample for the fit. The FAST paper uses
a few thousand; 1024 is a reasonable default for medium datasets."""
# Knowledge insulation — paper §III.B --------------------------------
# When enabled, gradients from the action expert's flow loss are
# blocked from flowing back into the VLM's K/V projections. This
# prevents the action loss from over-fitting the language backbone
# to robot-specific features. Implemented in ``modeling_pi052`` as
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
# that splits queries into VLM and action halves and ``.detach()``-s
# the VLM K/V tensors used in the action-half's attention.
knowledge_insulation: bool = False
"""If True, route every transformer layer through the KI
attention path that blocks action→VLM gradient flow on K/V."""
# Learning-rate defaults --------------------------------------------
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
# is the LM-head LR multiplier: pi05 has no text supervision so the
# head doesn't get gradients; pi052 always has text supervision
# (subtask / memory / VQA) via the recipe, and under KI the LM head
# only sees gradients on ~3045% of the batch (the text-CE mask
# share of the recipe). Under aggressive cosine decay this is too
# weak to keep the head pinned, so it drifts back toward PaliGemma's
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
# + tied ``embed_tokens`` into their own param group while sharing
# the same cosine lambda, so the 5x ratio is preserved across decay.
lm_head_lr_scale: float = 5.0
# PaLM-style z-loss on text CE. Penalises the log-partition function
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
# while CE stays low, because a uniform additive logit bias cancels in
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
# commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
# unconditionally at model build time — see ``_enable_hf_kernels``
# in ``modeling_pi052``. The patch is process-global, idempotent
# and degrades gracefully if ``liger-kernel`` is missing. Measured
# at -4.5% step time on H100 (bench job 22161421); peak memory
# unchanged. ``fused_linear_cross_entropy`` ships separately via
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
use_hf_kernels: bool = True
"""Deprecated. Liger HF kernels are patched unconditionally by
``_enable_hf_kernels`` — this field is retained as a no-op for
backward compatibility with checkpoints saved before commit
d70c8104 (which still serialize ``use_hf_kernels: true`` into
``config.json``). Loading those configs would otherwise raise
``DecodingError: The fields use_hf_kernels are not valid for
PI052Config`` (job 22164492). Remove in a future major bump."""
# Optimizer foreach/fused. pi052 carries these locally because the shared
# PI05Config (kept identical to upstream main) does not define them; the
# checkpoints we train serialize both keys into config.json, so they must
# be valid PI052Config fields and flow into the AdamW preset below.
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when
# we're training it. Override the π0.5 default
# (``train_expert_only=True``) unless the user explicitly opts
# out of text training via ``text_loss_weight=0``.
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
self.train_expert_only = False
@@ -1,304 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset-specific FAST action tokenizer fitting.
The published ``physical-intelligence/fast`` tokenizer is a *universal*
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
π0.5 itself, the recommended practice is to **finetune the tokenizer on
your specific dataset's action distribution** before training the
policy — same way one would adapt a language tokenizer to a domain
corpus. Without this finetune step, action sequences from your robot
may require more tokens per chunk than necessary, lowering effective
compression and slowing convergence of the action-CE loss.
This module provides a single utility, :func:`fit_fast_tokenizer`,
that does the finetune. The training entry point invokes it
automatically when the policy's ``enable_fast_action_loss`` and
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
The fitted tokenizer is saved to
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
runs over the same dataset re-use it.
"""
from __future__ import annotations
import hashlib
import logging
import os
import time
from pathlib import Path
import numpy as np
logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature(
dataset_repo_id: str,
base_tokenizer_name: str,
n_samples: int,
chunk_size: int,
) -> str:
"""Deterministic short hash for naming the cache directory.
Keys on (dataset, base tokenizer, sample count, chunk size) so any
of those changing re-runs the fit. ``chunk_size`` matters because
the tokenizer is fit on chunks of that length.
"""
h = hashlib.sha256()
h.update(dataset_repo_id.encode("utf-8"))
h.update(b"\0")
h.update(base_tokenizer_name.encode("utf-8"))
h.update(b"\0")
h.update(str(n_samples).encode("utf-8"))
h.update(b"\0")
h.update(str(chunk_size).encode("utf-8"))
return h.hexdigest()[:16]
def fit_fast_tokenizer(
*,
dataset_repo_id: str,
cache_dir: str | Path,
base_tokenizer_name: str = "physical-intelligence/fast",
n_samples: int = 1024,
chunk_size: int = 50,
seed: int = 42,
) -> str:
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
Args:
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
cache_dir: Directory under which to save (and look up) fitted
tokenizers. The actual save path is
``{cache_dir}/{signature}``.
base_tokenizer_name: HF identifier for the base FAST tokenizer
to finetune from. ``physical-intelligence/fast`` is the
universal one.
n_samples: Number of action chunks to sample for the fit. The
FAST paper uses a few thousand; ``1024`` is a good default
for medium datasets.
chunk_size: Length of each action chunk (matches
``policy.chunk_size``). The FAST tokenizer is fit on
sequences of this length.
seed: RNG seed for sample selection.
Returns:
The local path to the fitted tokenizer. Passed directly to
``--policy.action_tokenizer_name`` for the training run.
Raises:
ImportError: If the ``transformers`` library doesn't expose
``AutoProcessor`` or the FAST tokenizer doesn't have a
``.fit()`` method (then you're on an older FAST snapshot —
update to the current published model).
FileNotFoundError: If the dataset can't be loaded.
"""
cache_dir = Path(cache_dir)
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d",
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
)
return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s",
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
)
from transformers import AutoProcessor # noqa: PLC0415
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
# Stream a single episode's worth of action chunks at a time so
# we don't blow memory on huge datasets. Random episode +
# random start offset gives a reasonable spread.
#
# Actions are read straight from the underlying HF dataset's
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
# training item (delta-timestamp expansion + video decode + image
# transforms); a single bad video frame would then throw and, since
# the failure was swallowed at debug level, silently starve the fit
# of every chunk. The action column carries no video, so reading it
# directly is both faster and immune to decode errors.
rng = np.random.default_rng(seed)
actions_buf: list[np.ndarray] = []
# Resolve the dataset's data parquet shards directly, sidestepping
# ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format
# datasets routes through HF datasets'' split lookup and raises
# ``ValueError: Instruction "train" corresponds to no data!`` for
# every episode (job 22182985 looped through 13,293 skipped episodes
# for ~2.5 h before NCCL killed it). Reading the ``action`` column
# straight from the parquet shards is also faster: each per-episode
# ``LeRobotDataset`` instantiation re-parses every meta file.
from huggingface_hub import snapshot_download # noqa: PLC0415
import pyarrow as _pa # noqa: PLC0415
import pyarrow.parquet as _pq # noqa: PLC0415
snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset"))
data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet"))
if not data_files:
raise RuntimeError(
f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}."
)
# Read just the (episode_index, action) columns once across all
# shards. This is the same pattern used elsewhere in the codebase
# for whole-dataset audits and stays under ~2 GB even on 32 k-episode
# / 29 M-frame datasets because the action column is a fixed-length
# float vector.
tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files]
table = _pa.concat_tables(tables)
eps = table["episode_index"].to_numpy()
acts_col = table["action"]
# ``action`` may be a fixed-shape ListArray or a 2-D NumericArray;
# ``to_numpy(zero_copy_only=False)`` produces an object array of
# 1-D NumPy actions either way, which we stack into (N, D).
try:
acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32)
except Exception: # noqa: BLE001
# Fallback path for nested-list types: flatten via to_pylist().
acts = np.asarray(acts_col.to_pylist(), dtype=np.float32)
if acts.ndim != 2:
raise RuntimeError(
f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}."
)
# Episode index → slice (start, stop) into ``acts`` along axis 0.
# ``eps`` is monotonically increasing within each parquet shard but
# we make no assumption across shards — sort once and group.
order = np.argsort(eps, kind="stable")
eps_sorted = eps[order]
boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2))
ep_to_slice: dict[int, tuple[int, int]] = {
int(ep): (int(boundaries[ep]), int(boundaries[ep + 1]))
for ep in range(len(boundaries) - 1)
if boundaries[ep] < boundaries[ep + 1]
}
num_episodes = len(ep_to_slice)
# ``acts`` is in original (un-sorted-by-episode) row order; reorder
# so per-episode slices are contiguous.
acts = acts[order]
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
collected = 0
eps_visited = 0
short_episodes = 0
ep_indices = list(ep_to_slice.keys())
for ep_idx in rng.permutation(ep_indices):
if collected >= n_samples:
break
start, stop = ep_to_slice[int(ep_idx)]
ep_actions = acts[start:stop]
if ep_actions.shape[0] < chunk_size:
short_episodes += 1
continue
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
for s in starts:
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
collected += 1
if collected >= n_samples:
break
eps_visited += 1
if not actions_buf:
raise RuntimeError(
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
f"all {num_episodes} episodes were shorter than chunk_size="
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
"``action`` column. Lower ``chunk_size`` to match your episode "
"lengths."
)
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
logger.info(
"FAST fit: collected %d chunks of shape %s from %d episodes",
actions.shape[0], actions.shape[1:], eps_visited,
)
# Quantile-normalise per dimension before fitting.
#
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
# rounds to integer tokens; the integer *range* must fit the
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
# ticks) blow that range up — hence "Vocab size 1024 is too small".
# More importantly, at training time ``ActionTokenizerProcessorStep``
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
# encodes normalised actions. Fitting on raw actions would mismatch
# that space. We replicate QUANTILES normalisation here (per-dim
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
# encode see the same distribution.
flat = actions.reshape(-1, actions.shape[-1])
q01 = np.quantile(flat, 0.01, axis=0)
q99 = np.quantile(flat, 0.99, axis=0)
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
if not hasattr(base, "fit"):
raise ImportError(
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
"method — your transformers / model snapshot is too old. Update "
"to the current ``physical-intelligence/fast`` revision."
)
fitted = base.fit(actions)
out_dir.mkdir(parents=True, exist_ok=True)
fitted.save_pretrained(str(out_dir))
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
return str(out_dir)
@@ -1,73 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 inference / runtime orchestration.
Multi-rate runtime that mirrors the recipe-time training shape:
low_level_execution → LowLevelForward + DispatchAction (high Hz)
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
memory_update → MemoryUpdateFwd (event: subtask_change)
user_interjection_response → UserInterjectionFwd (event: stdin)
ask_vqa_* → AskVQAFwd (event: stdin question)
speech tool calls → DispatchToolCalls (event: tool_call_pending)
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
``run()``.
"""
from .repl import StdinReader
from .runtime import PI052Runtime
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
UserInterjectionFwd,
)
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
from .ui import make_state_panel, print_robot_lines, print_user_line
__all__ = [
# runtime
"PI052Runtime",
"StdinReader",
# state helpers
"initial_runtime_state",
"push_log",
"set_if_changed",
"take_event",
# triggers
"Trigger",
"Tick",
"TickClock",
"HzTrigger",
"EventTrigger",
# steps
"InferenceStep",
"LowLevelForward",
"DispatchAction",
"HighLevelSubtaskFwd",
"MemoryUpdateFwd",
"UserInterjectionFwd",
"AskVQAFwd",
"DispatchToolCalls",
# UI
"make_state_panel",
"print_robot_lines",
"print_user_line",
]
@@ -1,105 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stdin REPL event collector for the PI052 runtime.
Reads non-blocking stdin lines, classifies each one heuristically:
"stop" / "quit" / "exit" → state["stop"] = True
"/action" / "/pause" → set state["mode"]
ends with "?" → user_vqa_query event
starts with "task:" or first line → set runtime task
anything else → user_interjection event
Plugged into the runtime via ``event_collector=StdinReader().poll``.
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
directly in its REPL / autonomous loops and does *not* wire this
collector; it's kept as the documented embedding hook and for tests.
"""
from __future__ import annotations
import select
import sys
from dataclasses import dataclass, field
from typing import Any
@dataclass
class StdinReader:
"""Non-blocking stdin line collector for the runtime loop."""
prompt: str = "> "
_seen_first_line: bool = field(default=False, init=False)
_prompted: bool = field(default=False, init=False)
def poll(self, state: dict[str, Any]) -> None:
"""Drain pending stdin lines into runtime events."""
# Print the input prompt once on every fresh tick if we don't
# already have a pending line; matches the expected REPL feel.
if not self._prompted:
print(self.prompt, end="", flush=True)
self._prompted = True
# ``select`` with timeout=0 makes this non-blocking. Only works
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
try:
ready, _, _ = select.select([sys.stdin], [], [], 0)
except (ValueError, OSError):
return
if not ready:
return
line = sys.stdin.readline()
if not line: # EOF
state["stop"] = True
return
line = line.strip()
self._prompted = False # we'll re-prompt next tick
if not line:
return
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
state["stop"] = True
return
# Slash commands flip the run mode. ``/pause`` stops the action
# loop (the action steps gate on ``state["mode"]``); ``/action``
# resumes it.
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
state["mode"] = "action"
return
if lower in {"/pause", "/p"}:
state["mode"] = "paused"
queue = state.get("action_queue")
if hasattr(queue, "clear"):
queue.clear()
return
# First non-control line sets the task if no task is active.
if not state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
state["task"] = task
print(f"[pi052] Task: {task}", flush=True)
self._seen_first_line = True
return
# Question → VQA; statement → interjection.
if lower.endswith("?"):
state["recent_vqa_query"] = line
state.setdefault("events_this_tick", []).append("user_vqa_query")
else:
state["recent_interjection"] = line
state.setdefault("events_this_tick", []).append("user_interjection")
@@ -1,205 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 runtime loop.
Threads the multi-rate inference pipeline together with a stdin REPL
event collector, drives ticks through :class:`TickClock`, and prints
state-change updates to the user.
"""
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable
from .runtime_state import initial_runtime_state, push_log
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
)
from .triggers import EventTrigger, HzTrigger, TickClock
logger = logging.getLogger(__name__)
@dataclass
class PI052Runtime:
"""Compose the inference pipeline and drive it tick-by-tick."""
policy: Any
tools: dict[str, Any] = field(default_factory=dict)
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
from :func:`lerobot.tools.get_tools(meta)` when wiring the
runtime."""
observation_provider: Callable[[], dict | None] | None = None
"""Closure returning the current preprocessed observation batch.
``None`` for dry-run / language-only sessions."""
robot_executor: Callable[[Any], None] | None = None
"""Closure that takes one action chunk and forwards it to the
robot. ``None`` for dry-run."""
event_collector: Callable[[dict], None] | None = None
"""Per-tick hook that polls external sources (stdin, network) and
appends event names to ``state["events_this_tick"]``."""
chunk_hz: float = 4.0
ctrl_hz: float = 50.0
high_level_hz: float = 1.0
max_rate_hz: float = 50.0
pipeline: list[InferenceStep] = field(init=False)
state: dict[str, Any] = field(init=False)
_stop: bool = field(default=False, init=False)
def __post_init__(self) -> None:
# Subtask + memory + VQA configuration. Pipeline:
#
# HighLevelSubtaskFwd → generate the next subtask via the LM
# head at ~``high_level_hz``; writes
# ``current_subtask`` and emits
# ``subtask_change`` on a transition.
# MemoryUpdateFwd → on ``subtask_change``, refresh
# ``current_memory`` from the
# ``memory_update`` head.
# AskVQAFwd → answer camera-grounded stdin questions.
# LowLevelForward → action chunk conditioned on the
# generated ``current_subtask``.
# DispatchAction → drain the chunk to the robot.
# DispatchToolCalls → fire any pending tool calls.
#
# Order matters: ``HighLevelSubtaskFwd`` must run before
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
# both must run before ``LowLevelForward`` (which is gated on
# "action queue empty") so the chunk consumes the freshest
# subtask. ``UserInterjectionFwd`` is still importable but
# disabled until plan generation is wired in.
self.pipeline = [
HighLevelSubtaskFwd(
trigger=HzTrigger(self.high_level_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
# Listens for the ``subtask_change`` event raised by
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
MemoryUpdateFwd(
trigger=EventTrigger("subtask_change"),
policy=self.policy,
observation_provider=self.observation_provider,
),
AskVQAFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
LowLevelForward(
trigger=HzTrigger(self.chunk_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
DispatchAction(
trigger=HzTrigger(self.ctrl_hz),
robot_executor=self.robot_executor,
),
DispatchToolCalls(tools=self.tools),
]
self.state = initial_runtime_state()
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def set_task(self, task: str) -> None:
"""Set or replace the active task. Logged for the REPL."""
self.state["task"] = task
push_log(self.state, f"Task: {task}")
def stop(self) -> None:
self._stop = True
def run(self, *, max_ticks: int | None = None) -> None:
"""Main loop. Returns when ``stop()`` is called or after
``max_ticks`` ticks (useful for tests / dry-run)."""
clock = TickClock(max_rate_hz=self.max_rate_hz)
while not self._stop:
tick = clock.advance()
self.state["_tick"] = tick
self.state["events_this_tick"] = []
self.state["log_lines"] = []
if self.event_collector is not None:
self.event_collector(self.state)
if self.state.get("stop"):
self._stop = True
break
for step in self.pipeline:
self.state = step(self.state)
self._flush_logs()
if max_ticks is not None and tick.index >= max_ticks:
break
self._on_shutdown()
# ------------------------------------------------------------------
# REPL helper: drive one full pipeline pass and return its logs
# ------------------------------------------------------------------
def step_once(self) -> list[str]:
"""Run one tick of the pipeline and return the log lines.
Used by the interactive REPL: instead of a background thread,
the CLI drives ticks synchronously after each user input. Logs
are returned (not printed) so the caller can route them into
the rich-Live chat scrollback.
"""
from .triggers import Tick # noqa: PLC0415
# Synthesize a tick. We don't need the real wall-clock pacing
# here — the REPL drives the runtime, not vice versa — but
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
# bump it generously so every Hz-triggered step considers
# itself due.
import time as _time # noqa: PLC0415
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
self.state["log_lines"] = []
# ``events_this_tick`` is set up by the caller before
# ``step_once`` (the REPL pushes user-driven events first).
self.state.setdefault("events_this_tick", [])
for step in self.pipeline:
self.state = step(self.state)
return list(self.state.get("log_lines") or [])
# ------------------------------------------------------------------
# I/O
# ------------------------------------------------------------------
def _flush_logs(self) -> None:
for line in self.state.get("log_lines") or []:
print(f"[pi052] {line}", flush=True)
def _on_shutdown(self) -> None:
# Drain any queued action chunks safely.
queue = self.state.get("action_queue")
if isinstance(queue, deque):
queue.clear()
print("[pi052] runtime stopped", flush=True)
@@ -1,95 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime state passed between inference steps each tick.
The runtime threads a single dict through the pipeline; this module
documents the shape and provides factories. We use a plain ``dict``
rather than a frozen dataclass because steps freely add and remove
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
) and dataclass field churn would just get in the way.
Stable keys (read by multiple steps):
task str the current top-level task
current_plan str | None latest plan emitted by the planner
current_subtask str | None latest subtask the policy is executing
current_memory str | None latest compressed memory
recent_interjection str | None most recent user interjection text (consumed)
action_queue collections.deque[Tensor] pending action chunks
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
events_this_tick list[str] triggers consumed this tick
_tick Tick current tick (set by the loop)
mode str "action" (run the robot) | "paused"
(action loop stopped robot holds)
log_lines list[str] human-readable status lines printed each tick
"""
from __future__ import annotations
from collections import deque
from typing import Any
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
"""Build a fresh runtime state dict with sensible defaults."""
return {
"task": task,
"current_plan": None,
"current_subtask": None,
"current_memory": None,
"recent_interjection": None,
"action_queue": deque(),
"tool_calls_pending": [],
"events_this_tick": [],
"log_lines": [],
"mode": "action",
"stop": False,
}
def take_event(state: dict[str, Any], event_name: str) -> bool:
"""Pop ``event_name`` from ``events_this_tick`` if present.
Steps that consume an event call this so the same event doesn't
re-fire on a sibling step within the same tick.
"""
events: list[str] = state.get("events_this_tick") or []
if event_name in events:
events.remove(event_name)
return True
return False
def push_log(state: dict[str, Any], line: str) -> None:
"""Append ``line`` to the per-tick log buffer; the runtime prints
it at the end of the tick."""
state.setdefault("log_lines", []).append(line)
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
"""Update ``state[key]`` and log a diff line if the value changed.
Returns ``True`` if the value actually changed.
"""
prev = state.get(key)
if prev == value:
return False
state[key] = value
if label is not None:
push_log(state, f" {label}: {value}")
return True
@@ -1,955 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference steps for the PI052 multi-rate runtime.
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
the runtime applies them in order each tick. When a step's trigger
doesn't fire, the step is a no-op and the runtime moves on.
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
* ``LowLevelForward`` calls ``policy.select_action`` for the
action chunk; trained by
``low_level_execution``
* ``EnqueueChunk`` pushes the chunk to ``action_queue``
* ``DispatchAction`` pops one action per control tick and
forwards to the robot
* ``HighLevelSubtaskFwd`` calls ``policy.select_message`` for the
next subtask; trained by
``high_level_subtask``
* ``MemoryUpdateFwd`` fires on subtask boundary; trained by
``memory_update``
* ``UserInterjectionFwd`` fires on stdin interjection; trained by
``user_interjection_response``
* ``AskVQAFwd`` fires on stdin question; trained by
``ask_vqa_*``
* ``DispatchToolCalls`` pops ``tool_calls_pending`` and calls
the matching ``Tool`` instance
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from .runtime_state import push_log, set_if_changed, take_event
from .triggers import EventTrigger, HzTrigger, Trigger
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Step base + runner
# ---------------------------------------------------------------------------
@dataclass
class InferenceStep:
"""A trigger-gated callable. Subclasses override :meth:`run`."""
trigger: Trigger
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
if not self.trigger.should_fire(state["_tick"], state):
return state
return self.run(state) or state
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
raise NotImplementedError
# ---------------------------------------------------------------------------
# Low-level (action) path
# ---------------------------------------------------------------------------
@dataclass
class LowLevelForward(InferenceStep):
"""Run the policy's action head and produce one action chunk."""
policy: Any = None
observation_provider: Any = None
"""Callable ``() -> dict``: returns the current observation batch
(already preprocessed). Typically wraps the robot's camera /
proprio reads. ``None`` in dry-run mode step skips."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or self.observation_provider is None:
return None
# ``/vlm`` mode pauses the whole action loop so the robot holds
# position while the operator probes the VLM with VQA.
if state.get("mode", "action") != "action":
return None
if not state.get("task"):
return None
# PI052 produces *action chunks* (typically 50 steps via
# flow-matching). Every step gets dispatched to the robot;
# popping one per dispatch tick is essentially free. Only
# generate a new chunk once the previous one has fully
# drained — this is the canonical "sense → think → act"
# loop. Refreshing while a chunk is still queued causes the
# new chunk to "telescope" past the old one (planned from an
# observation that's already 25+ steps stale by the time it
# starts dispatching).
queue = state.setdefault("action_queue", [])
if len(queue) > 0:
return None
observation = self.observation_provider()
if observation is None:
return None
# The action expert is conditioned on the SUBTASK generated by
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
# the pipeline and writes ``current_subtask``). Matches the
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
# Falls back to the task string only on the very first frame,
# before the high-level loop has produced a subtask.
subtask = state.get("current_subtask") or state.get("task") or ""
ctx = [{"role": "user", "content": subtask}]
# ``add_generation_prompt=False`` to match the training-time
# prefix shape: at training the action expert sees the rendered
# user turn ending at ``<|im_end|>`` (no trailing
# ``<|im_start|>assistant\n``). Passing True here would append
# extra role-marker tokens the action expert never saw during
# training.
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
observation = dict(observation)
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
try:
# ``predict_action_chunk`` returns the *full* chunk shape
# ``(batch, n_action_steps, action_dim)``. Enqueue every
# step so DispatchAction at ctrl_hz can drain them
# smoothly until the next refresh.
chunk = self.policy.predict_action_chunk(observation)
except Exception as exc: # noqa: BLE001
logger.warning(
"predict_action_chunk failed: %s",
exc,
exc_info=logger.isEnabledFor(logging.DEBUG),
)
push_log(
state,
f" [warn] predict_action_chunk failed: "
f"{type(exc).__name__}: {exc}",
)
return None
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
# each step as a ``(1, action_dim)`` tensor so the existing
# action executor's batch-squeeze logic works unchanged.
if chunk.ndim == 3:
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
elif chunk.ndim == 2:
chunk_iter = chunk
else:
chunk_iter = chunk.unsqueeze(0)
for step in chunk_iter:
queue.append(step.unsqueeze(0))
state["last_chunk_size"] = int(chunk_iter.shape[0])
return None
@dataclass
class DispatchAction(InferenceStep):
"""Pop one action per tick and hand it to the robot.
In dry-run mode (``robot_executor=None``) the step still pops the
queue so it doesn't grow unbounded — the popped tensor is logged
instead of executed.
Wall-clock catch-up: the action queue represents an open-loop
trajectory at a fixed step rate (``trigger.hz`` ``ctrl_hz``).
When the main loop stalls e.g. an LLM call for the high-level
subtask blocks for ~2 s on MPS the dispatch trigger fires only
once over that whole interval. Naively popping a single entry per
fire makes the robot lag further and further behind the planned
timeline, and a 50-step chunk would take ~125 s to drain instead
of ~1.7 s. Track real elapsed time between dispatches and pop
``round(elapsed * hz)`` entries, sending the most recent one. The
skipped intermediate joint targets are stale anyway the dynamixel
will smooth toward the latest goal position.
"""
robot_executor: Any = None
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
_last_dispatch_t: float | None = field(default=None, init=False)
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
import time as _time # noqa: PLC0415
# ``/vlm`` mode pauses dispatch — the robot holds its last
# commanded position while the operator runs VQA.
if state.get("mode", "action") != "action":
self._last_dispatch_t = None
return None
queue = state.get("action_queue")
if not queue:
# Reset wall-clock anchor when the queue is empty so the
# next chunk doesn't see a huge fake "elapsed" window.
self._last_dispatch_t = None
return None
now = _time.monotonic()
hz = getattr(self.trigger, "hz", 30.0)
if self._last_dispatch_t is None or hz <= 0:
n_to_pop = 1
else:
elapsed = now - self._last_dispatch_t
# ``max(1, ...)`` so we always pop at least one when the
# trigger fires; ``min(len(queue), ...)`` so we don't run
# off the end of the chunk.
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
self._last_dispatch_t = now
# Drain ``n_to_pop`` stale entries, keep only the latest as the
# action actually sent. The intermediate joint targets would
# all be ~1030 ms apart in chunk time — the robot can't track
# them individually anyway when the host loop is slow.
latest = None
for _ in range(n_to_pop):
if not queue:
break
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
if latest is not None and self.robot_executor is not None:
self.robot_executor(latest)
return None
# ---------------------------------------------------------------------------
# High-level (text) paths — all use policy.select_message
# ---------------------------------------------------------------------------
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}
def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
"""Return a loc-token-registered tokenizer, loading from disk only once.
``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and
the result is immutable, so cache per ``tok_name``.
"""
tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name)
if tokenizer is None:
tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
_LOC_TOKENIZER_CACHE[tok_name] = tokenizer
return tokenizer
def _build_text_batch(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""Tokenize chat messages into the batch ``select_message`` expects.
PI052's backbone (PaliGemma) ships no chat template, so we train on
a plain role-prefixed concatenation built by
``PI052TextTokenizerStep``. We reuse that exact formatter so the
inference prefix matches training; ``add_generation_prompt`` appends
the bare ``Assistant: `` header the LM head continues from.
"""
import torch # noqa: PLC0415
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
_flatten_say_tool_calls,
_format_messages,
_strip_blocks,
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
# Register PaliGemma's <locDDDD> tokens so inference encoding /
# decoding sees them as single vocab ids — must match training.
# The tokenizer is read-only after registration, so cache it: rebuilding it
# from disk on every call dominated eval runtime (this runs twice per env
# per replan — subtask gen + action prompt).
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
if add_generation_prompt:
prompt = prompt + "Assistant: "
encoded = tokenizer(prompt, return_tensors="pt")
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
new = dict(m)
new.pop("stream", None)
new.pop("target", None)
return new
@dataclass
class HighLevelSubtaskFwd(InferenceStep):
"""At ~1 Hz, ask the policy for the next subtask.
Mirrors the ``high_level_subtask`` recipe layout exactly:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
user: "Current subtask: ${subtask}" (if subtask present)
generate
assistant: <next subtask>
"""
policy: Any = None
observation_provider: Any = None
"""Same shape as ``LowLevelForward.observation_provider``. When
set, the resulting observation is merged into ``select_message``'s
batch so text generation runs against real video + state."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"):
return None
# ``/vlm`` mode pauses subtask generation along with the rest of
# the action loop.
if state.get("mode", "action") != "action":
return None
# Gate to chunk boundaries: only generate a fresh subtask when
# the action queue is empty (i.e. right before LowLevelForward
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
# and running it every loop iteration starves DispatchAction
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
# of 30/sec and the robot barely moves. Tying it to the same
# "queue empty" condition as the chunk refresh produces a
# clean sense → think → act cycle.
#
# Rearm the trigger when skipping so a low-hz schedule
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
# the slot: the trigger fires once on the timer but the brief
# queue-empty window almost never coincides, so without rearm
# HL would effectively never run.
queue = state.get("action_queue") or []
if len(queue) > 0:
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
# Per-chunk-boundary throttle: at each "queue empty" moment we
# increment a counter; subtask gen only fires once the counter
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
# 5 action chunks per subtask-gen so the LM head doesn't churn
# every 1.7 s (a fresh subtask while the previous one is still
# being executed is wasted compute *and* causes the action
# expert's flow trajectory to be re-planned mid-grasp).
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
# Initialise so the first chunk boundary fires immediately
# (counter starts at chunks_per_gen, decrements per skip,
# generates and resets when it hits 0).
if "_hl_chunks_until_gen" not in state:
state["_hl_chunks_until_gen"] = 0
if state["_hl_chunks_until_gen"] > 0:
state["_hl_chunks_until_gen"] -= 1
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Default: greedy argmax, no min_new_tokens, no special-token
# suppression — matches training. Operator can override via
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
# on the CLI; useful for under-trained checkpoints whose LM
# head still favours EOS at position 0 (pre-trained chat
# backbone's short-turn prior hasn't been fully overridden
# by the fine-tuning supervision yet).
msg = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
# Subtasks never legitimately contain PaliGemma ``<loc>``
# tokens — suppress them so a checkpoint whose LM head
# has drifted toward the pretrained loc-prior falls back
# to its (still-correct) text mass.
suppress_loc_tokens=True,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
# repeats. Memorisation collapse looks like "same accepted
# subtask N times in a row" or "gibberish_count rising while
# current_subtask is stuck". The state panel renders these.
state["last_subtask_raw"] = msg or ""
# Persistent empty completion is its own failure mode (model
# immediately EOS-es from the chat-template generation
# prompt) — surface it once every N occurrences so the
# operator can distinguish "generation failing silently"
# from "generating fine but filter rejecting".
if not msg:
empties = state.get("subtask_empty_count", 0) + 1
state["subtask_empty_count"] = empties
if empties == 1 or empties % 5 == 0:
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
if debug:
push_log(
state,
f" [info] subtask gen empty (×{empties}); {debug}",
)
else:
push_log(
state,
f" [info] subtask gen returned empty (×{empties}) — "
"no tokens generated (head EOS-ing before any "
"non-special token).",
)
if msg and _looks_like_gibberish(msg):
# Bump a counter so the operator can see the model is
# struggling without spamming the log every tick. A first
# rejection still logs once so the failure is visible.
count = state.get("subtask_gibberish_count", 0) + 1
state["subtask_gibberish_count"] = count
if count == 1 or count % 30 == 0:
push_log(
state,
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
)
return None
if msg:
prev_subtask = state.get("current_subtask")
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
# Stash the just-completed subtask so ``MemoryUpdateFwd``
# can drop it into its prompt as ``Completed subtask:``
# — the recipe binds ``completed_subtask`` to
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
# that was active *before* the change.
if prev_subtask:
state["prior_subtask"] = prev_subtask
# Subtask change is a downstream trigger.
state.setdefault("events_this_tick", []).append("subtask_change")
state["subtask_repeat_count"] = 0
else:
# Same accepted string regenerated — memorisation tell.
# Once this counter climbs past a few, you're seeing
# the model unable to move past the current subtask
# despite the chunk having drained (visual scene may
# have changed but the LM is replaying training
# tokens).
state["subtask_repeat_count"] = (
state.get("subtask_repeat_count", 0) + 1
)
# Silently skip empty completions — common when the model
# warms up or generates only EOS; logging it every tick at
# ctrl_hz is just noise.
return None
@dataclass
class MemoryUpdateFwd(InferenceStep):
"""On subtask boundary, refresh the compressed memory.
Mirrors the ``memory_update`` recipe layout exactly:
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if prior memory)
user: "Completed subtask: ${completed_subtask}" (if subtask)
generate
assistant: <new memory>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
# Don't consume the event — multiple steps may want to react.
if self.policy is None:
return None
ctx = _msgs_for_memory(state)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="memory gen",
suppress_loc_tokens=True,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
count = state.get("memory_gibberish_count", 0) + 1
state["memory_gibberish_count"] = count
push_log(
state,
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
)
return None
if new_memory:
set_if_changed(state, "current_memory", new_memory, label="memory")
return None
@dataclass
class UserInterjectionFwd(InferenceStep):
"""On stdin interjection, refresh the plan + emit a paired ``say``.
Mirrors the ``user_interjection_response`` recipe layout exactly:
user: "${task}"
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
user: "${interjection}" (the new utterance)
generate
assistant: <plan + <say>...</say>>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_interjection"):
return None
ctx = _msgs_for_interjection(state)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="plan/say gen",
suppress_loc_tokens=True,
)
if not out:
# Don't log every empty completion — happens repeatedly on
# MPS during warm-up and floods the panel. The user can
# re-trigger by typing again.
return None
if _looks_like_gibberish(out):
count = state.get("plan_gibberish_count", 0) + 1
state["plan_gibberish_count"] = count
push_log(
state,
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
)
return None
# Heuristic split: model is trained to emit one assistant turn
# carrying both plan text AND a `say` tool call. Look for a
# "<say>...</say>" or "say(...)" marker; fall back to whole
# text → plan, no speech.
plan_text, speech_text = _split_plan_and_say(out)
if plan_text and _looks_like_gibberish(plan_text):
plan_text = ""
if plan_text:
set_if_changed(state, "current_plan", plan_text, label="plan")
if speech_text:
push_log(state, f" speech: {speech_text}")
state.setdefault("tool_calls_pending", []).append(
{
"type": "function",
"function": {"name": "say", "arguments": {"text": speech_text}},
}
)
state.setdefault("events_this_tick", []).append("tool_call_pending")
# Mark interjection consumed.
state["recent_interjection"] = None
return None
@dataclass
class AskVQAFwd(InferenceStep):
"""On stdin question, answer a frame-grounded VQA.
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
turn carrying just the VQA question, plus the camera image block
in training (we drop the image at inference because the dataset's
image preprocessing doesn't match SmolVLM's vision tower input).
user: <question>
generate
assistant: <vqa answer>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_vqa_query"):
return None
question = state.get("recent_vqa_query")
if not question:
return None
ctx = _msgs_for_vqa(question)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="vqa gen",
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
# the answer as-is — the VQA panel line lets the user judge.
if answer:
push_log(state, f" vqa: {answer}")
state["recent_vqa_query"] = None
return None
# ---------------------------------------------------------------------------
# Tool dispatch
# ---------------------------------------------------------------------------
@dataclass
class DispatchToolCalls(InferenceStep):
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
tools: dict[str, Any] = field(default_factory=dict)
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
take_event(state, "tool_call_pending")
pending = state.get("tool_calls_pending") or []
for call in pending:
try:
fn = (call or {}).get("function") or {}
name = fn.get("name")
args = fn.get("arguments") or {}
tool = self.tools.get(name)
if tool is None:
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
continue
tool.call(args)
except Exception as exc: # noqa: BLE001
push_log(state, f" [error] tool dispatch failed: {exc}")
state["tool_calls_pending"] = []
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _looks_like_gibberish(text: str) -> bool:
"""Heuristically detect generation that's clearly off the rails.
Memorised models can collapse to dominant-mode outputs when the
prompt drifts even slightly from training distribution. Reject:
* empty / whitespace-only
* too few alphabetic characters (mostly punctuation)
* a single character repeated past the threshold
* starts with ``":"`` and contains no letters
* too few unique tokens e.g. ``"the"``, ``"the the the"``,
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
where the model emits one or two memorised tokens repeatedly)
* chat-template fragment leakage (``Assistant:``, ``User:``,
``Ass\\n``)
Real subtasks look like ``"close the gripper to grasp the blue
cube"`` — multiple unique alphabetic tokens, no role-marker
fragments. Anything materially shorter than that is rejected.
"""
if not text or not text.strip():
return True
stripped = text.strip()
alpha = sum(1 for c in stripped if c.isalpha())
if alpha < max(3, len(stripped) // 8):
return True
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
return True
# Single repeating char: e.g. ``""""""``.
if len(set(stripped)) <= 2 and len(stripped) > 4:
return True
# Chat-template fragment leakage — the model emits ``Ass``,
# ``Assistant:``, ``User:``, often with extra newlines/colons.
# Reject if the cleaned text is mostly role-marker shards.
cleaned = stripped.replace("\n", " ").replace(":", " ")
for marker in ("Assistant", "User", "Ass "):
if marker in cleaned and len(cleaned.split()) < 4:
return True
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
unique_alpha = {t.lower() for t in tokens}
# Short degenerate output — model stuck on ``the`` or a couple of
# memorised single-token continuations.
if len(unique_alpha) < 3 and len(stripped) < 80:
return True
# Long repetition collapse — the LM head loops an n-gram for the
# whole generation budget ("the arm the arm … the the the the").
# Length-independent: many tokens but a tiny unique ratio. The
# earlier ``< 80`` check missed these because the looped string
# blows well past 80 chars.
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
return True
return False
def _control_context_messages(
state: dict[str, Any],
*,
include_completed: bool = False,
extra_user: str | None = None,
) -> list[dict[str, Any]]:
"""Build a chat-template-ready prompt from current runtime state.
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
${plan}\nMemory: ${memory}`` for the high-level branches.
"""
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
# values — to mirror the training-time recipe substitution.
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
if include_completed and state.get("current_subtask"):
parts.append(f"Completed subtask: {state['current_subtask']}")
head = "\n".join(parts)
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
if extra_user:
msgs.append({"role": "user", "content": extra_user})
return msgs
# ---------------------------------------------------------------------------
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
# prompt at inference matches what the model saw during training.
# Generic ``_control_context_messages`` is kept around as a fallback
# for ad-hoc callers but the four high-level steps now use these.
# ---------------------------------------------------------------------------
def _hirobot_user_head(state: dict[str, Any]) -> str:
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
Mirrors what the recipe renders at training time, where
``language_render._substitute`` substitutes empty strings for
missing ``${plan}`` / ``${memory}`` bindings i.e. the
``Plan: `` / ``Memory: `` prefix labels are *always* in the
user turn, even when their values aren't set yet. Skipping them
here (the previous behaviour) produced a different prompt shape
on early frames before plan / memory are populated and on
samples where the dataset has no plan / memory annotation.
"""
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
return f"{task}\nPlan: {plan}\nMemory: {memory}"
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``high_level_subtask`` recipe layout — predict the subtask from the
task. The v-current recipe's user turn is just ``${task}`` (plan and
memory are not trained), so the inference prompt is the bare task
no ``Plan: `` / ``Memory: `` lines.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
Recipe layout (``subtask_mem.yaml``):
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if_present prior)
user: "Completed subtask: ${completed}" (if_present completed)
assistant: predicts new memory
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
``state['current_memory']`` is the memory the policy last emitted
(= the ``prior_memory`` binding at training), and
``state['prior_subtask']`` is the subtask that just got replaced
(= the ``completed_subtask`` binding at training).
"""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""},
]
prior_memory = state.get("current_memory")
if prior_memory:
msgs.append(
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
)
completed_subtask = state.get("prior_subtask")
if completed_subtask:
msgs.append(
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
)
return msgs
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``user_interjection_response`` recipe layout."""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""}
]
if state.get("current_plan"):
msgs.append(
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
)
interjection = state.get("recent_interjection")
if interjection:
msgs.append({"role": "user", "content": interjection})
return msgs
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``plan_generation`` recipe layout — bare task → plan.
The assistant turn is the generation target, so we only render
the user turn at inference; the runtime appends the predicted
plan after sampling.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
return [{"role": "user", "content": question}]
def _maybe_observation(provider: Any) -> dict | None:
"""Pull one observation from ``provider`` if it's set, else ``None``.
Errors from the provider are logged at debug level and swallowed
text generation still runs (in text-only mode) so a flaky frame
source doesn't kill the REPL.
"""
if provider is None:
return None
try:
return provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s — falling back to text-only", exc)
return None
def _generate_with_policy(
policy: Any,
messages: list[dict[str, Any]],
*,
observation: dict | None = None,
state: dict[str, Any] | None = None,
label: str = "select_message",
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
suppress_loc_tokens: bool = False,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
When ``observation`` carries ``observation.images.*`` and
``observation.state``, those are merged into the batch so
``select_message`` runs the same VLM prefix the policy was trained
on. Without an observation the runtime falls back to a text-only
prompt the text head still runs, but generations may drift from
the training distribution.
Failures are surfaced both to the module logger (``warning``) and,
when ``state`` is given, to the runtime's user-visible log via
:func:`push_log`, so the REPL no longer "looks dead" when
something goes wrong inside generation.
"""
if not hasattr(policy, "select_message"):
if state is not None:
push_log(state, f" [warn] policy has no select_message — skipping {label}")
return ""
text_batch = _build_text_batch(policy, messages)
try:
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
}
if observation:
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
kwargs: dict[str, Any] = {
"tokenizer": text_batch["tokenizer"],
"min_new_tokens": min_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
return policy.select_message(batch, **kwargs)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None:
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
return ""
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
def _split_plan_and_say(text: str) -> tuple[str, str]:
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
The training-time tool-call serializer wraps ``say(text="")`` in a
deterministic textual marker so prefix-LM-style training learns to
emit it. The runtime parses it back here. If no marker is present,
the entire text is treated as plan with no speech.
"""
if not text:
return "", ""
match = _SAY_RE.search(text)
if not match:
return text.strip(), ""
speech = match.group(1).strip().strip('"').strip("'")
plan = (text[: match.start()] + text[match.end() :]).strip()
return plan, speech
@@ -1,134 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trigger primitives for PI052's multi-rate inference runtime.
Mirrors the plan's Section "Runtime orchestration": each
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
whether the step fires. Two trigger flavours cover all the cadences
the canonical recipe needs:
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
* :class:`EventTrigger` for one-shot reactions (subtask boundary
memory update; user interjection plan refresh; user VQA query
vqa answer; pending tool call dispatcher)
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
every step shares a single time source.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
@dataclass
class Tick:
"""Single tick from :class:`TickClock`. Carries time references the
runtime steps consume to gate themselves."""
index: int
"""Monotonic counter — increments by one per tick."""
monotonic_seconds: float
"""``time.monotonic()`` at the start of this tick."""
@dataclass
class TickClock:
"""Drives the runtime loop at up to ``max_rate_hz``.
Sleeps just enough between :meth:`advance` calls to enforce the
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
"""
max_rate_hz: float = 50.0
_index: int = field(default=0, init=False)
_last_seconds: float | None = field(default=None, init=False)
def advance(self) -> Tick:
period = 1.0 / max(self.max_rate_hz, 0.1)
now = time.monotonic()
if self._last_seconds is not None:
sleep_for = (self._last_seconds + period) - now
if sleep_for > 0:
time.sleep(sleep_for)
now = time.monotonic()
self._last_seconds = now
self._index += 1
return Tick(index=self._index, monotonic_seconds=now)
class Trigger(Protocol):
"""Decide whether the next ``InferenceStep`` should fire."""
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
@dataclass
class HzTrigger:
"""Fire at most ``hz`` times per second.
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
when the action queue is non-empty) and wants the trigger to
retry next tick instead of waiting a full period can call
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
brief queue-empty window and the step never fires at all.
"""
hz: float
_last_seconds: float | None = field(default=None, init=False)
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
period = 1.0 / max(self.hz, 1e-6)
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
self._last_seconds = tick.monotonic_seconds
return True
return False
def rearm(self) -> None:
"""Mark the trigger as not having fired, so the next tick re-evaluates.
Used by a step that decided to skip after ``should_fire`` already
committed the firing keeps the cadence honest without losing
the slot.
"""
self._last_seconds = None
@dataclass
class EventTrigger:
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
The runtime fills ``events_this_tick`` once per tick from:
* stdin / network input (``user_interjection``, ``user_vqa_query``,
``stop``)
* internal state transitions (``subtask_change``,
``tool_call_pending``)
The list is consumed (cleared at the end of the tick) so events
fire at most once.
"""
event_name: str
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
events: list[str] = state.get("events_this_tick") or []
return self.event_name in events
-127
View File
@@ -1,127 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rich-based REPL layout for the PI052 runtime.
Two-zone terminal layout:
[chat scrollback user messages / robot responses, scrolls naturally]
State
task please clean up the kitchen
subtask grasp the handle of the sponge
plan 1. grasp sponge 2. wipe 3. tidy
memory sponge picked up; counter still dirty
> _
The state panel re-renders on every state change. Chat lines are
``console.print``'d above the live region so they accumulate naturally
in scrollback. Implemented with :class:`rich.live.Live` plus
:func:`rich.console.Console.input` for the prompt when an input is
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
panel for cursor position.
"""
from __future__ import annotations
from typing import Any
try: # rich is optional; only required for the interactive REPL.
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
_HAS_RICH = True
except ImportError: # pragma: no cover
_HAS_RICH = False
Console = Any # type: ignore[assignment]
Panel = Any # type: ignore[assignment]
Table = Any # type: ignore[assignment]
Text = Any # type: ignore[assignment]
_STATE_KEYS = (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
)
def make_state_panel(state: dict[str, Any]) -> Any:
"""Render the persistent state panel for the live region.
Returns a :class:`rich.panel.Panel`. Caller passes it to
``Live.update(panel)`` whenever the state changes.
"""
if not _HAS_RICH:
raise RuntimeError(
"rich is required for the interactive REPL. "
"`pip install rich` (it's a transitive dep of lerobot)."
)
table = Table.grid(padding=(0, 2), expand=True)
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
table.add_column(justify="left")
for key, label in _STATE_KEYS:
value = state.get(key)
if value is None:
rendered = Text("(not set)", style="dim italic")
else:
rendered = Text(str(value), style="bold")
table.add_row(label, rendered)
queue = state.get("action_queue")
queue_len = len(queue) if hasattr(queue, "__len__") else 0
pending = state.get("tool_calls_pending") or []
footer = Text.assemble(
("queued actions: ", "dim"),
(str(queue_len), "bold cyan"),
(" pending tool calls: ", "dim"),
(str(len(pending)), "bold magenta"),
)
table.add_row("", footer)
run_mode = state.get("mode", "action")
mode_tag = (
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
)
return Panel(
table,
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
border_style="cyan",
)
def print_user_line(console: Any, line: str) -> None:
"""Append a user-typed line to the chat scrollback."""
if not _HAS_RICH:
print(f"you: {line}", flush=True)
return
console.print(f"[bold cyan]you:[/] {line}")
def print_robot_lines(console: Any, lines: list[str]) -> None:
"""Append robot/runtime log lines to the chat scrollback."""
if not _HAS_RICH:
for line in lines:
print(f"robot: {line.lstrip()}", flush=True)
return
for line in lines:
# The runtime uses leading whitespace + "label: text"; render
# the label in green and the value in default for readability.
stripped = line.lstrip()
if ":" in stripped:
label, _, value = stripped.partition(":")
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
else:
console.print(f"[bold green]robot:[/] {stripped}")
-423
View File
@@ -1,423 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive VQA for the PI052 runtime.
In ``/vlm`` mode a typed line is treated as a VQA question. This module
runs the full interactive flow:
1. pull the current observation and list available cameras,
2. ask the operator which camera to ground the question on,
3. generate the answer with the VLM conditioned on that one camera,
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
point (``keypoint``), draw the overlay on the camera frame, save a
PNG to ``./vqa_overlays/`` and auto-open it.
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
(see ``lerobot.annotations.steerable_pipeline.validator``):
* ``bbox`` ``{"detections": [{"label", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}, ...]}``
* ``keypoint`` ``{"label", "point_format": "xy", "point": [x, y]}``
* ``count`` / ``attribute`` / ``spatial`` text-only, no overlay.
"""
from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any
from .runtime_state import push_log
logger = logging.getLogger(__name__)
_IMAGE_PREFIX = "observation.images."
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
# normalized to the image axis) instead of pixel-coordinate JSON, so the
# answer string the runtime parses can be e.g.
# ``<loc0512><loc0301> blue cube`` (point) or
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
_BBOX_COLOR = (255, 64, 64)
_POINT_COLOR = (64, 220, 64)
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def available_cameras(observation: dict | None) -> list[str]:
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
if not observation:
return []
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
def camera_short_name(camera_key: str) -> str:
"""Strip the ``observation.images.`` prefix for display."""
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
def prompt_camera_choice(
cameras: list[str],
*,
input_fn: Any = input,
print_fn: Any = print,
) -> str | None:
"""Ask the operator which camera frame to draw a VQA overlay on.
Accepts either the menu number or the (short or full) camera name.
A single-camera setup auto-selects without prompting. Returns the
chosen ``observation.images.*`` key, or ``None`` if the operator
cancels / gives an invalid answer.
"""
if not cameras:
return None
if len(cameras) == 1:
return cameras[0]
print_fn("Draw the result on which camera?")
for i, cam in enumerate(cameras, 1):
print_fn(f" [{i}] {camera_short_name(cam)}")
try:
raw = str(input_fn("camera> ")).strip()
except (EOFError, KeyboardInterrupt):
return None
if not raw:
return cameras[0]
if raw.isdigit():
idx = int(raw) - 1
return cameras[idx] if 0 <= idx < len(cameras) else None
for cam in cameras:
if raw == cam or raw == camera_short_name(cam):
return cam
return None
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def _loc_to_norm(idx: int) -> float:
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
return max(0.0, min(1023.0, float(idx))) / 1023.0
def parse_loc_answer(answer: str) -> dict | None:
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
PI052 trains spatial answers in PaliGemma's native detection
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
this parser strips loc tokens and treats the remainder as the
label, so order is irrelevant. Coordinates come back *normalized*
([0, 1]); the overlay denormalizes them against the chosen camera
frame's pixel size.
Returns ``{"kind", "payload", "normalized": True}`` on success
(``payload`` mirrors the JSON shapes so the overlay code is shared),
or ``None`` when the answer carries no ``<loc>`` tokens.
"""
if not answer or "<loc" not in answer:
return None
segments = [seg for seg in answer.split(";") if "<loc" in seg]
points: list[tuple[float, float, str]] = []
boxes: list[tuple[float, float, float, float, str]] = []
for seg in segments:
locs = [int(m) for m in _LOC_RE.findall(seg)]
label = _LOC_RE.sub("", seg).strip()
if len(locs) == 2:
y, x = (_loc_to_norm(v) for v in locs[:2])
points.append((x, y, label))
elif len(locs) >= 4:
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
boxes.append((x1, y1, x2, y2, label))
if boxes:
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
for (x1, y1, x2, y2, lbl) in boxes
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
if len(points) == 1:
x, y, lbl = points[0]
return {
"kind": "keypoint",
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
"normalized": True,
}
if points: # several bare points → treat as detections-as-points
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
return None
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
spatial answers are detected first (PI052 trains them in that native
format). Returns ``None`` when the answer is neither ``<loc>`` text
nor a parseable JSON object.
"""
if not answer or not answer.strip():
return None
loc_parsed = parse_loc_answer(answer)
if loc_parsed is not None:
return loc_parsed
try:
payload = json.loads(answer)
except (ValueError, TypeError):
return None
if not isinstance(payload, dict):
return None
try:
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
VQA_ANSWER_SHAPES,
)
shapes = VQA_ANSWER_SHAPES
except ImportError: # pragma: no cover - annotation extra not installed
shapes = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
keys = set(payload)
for kind in _SHAPE_ORDER:
required = shapes.get(kind)
if required and required <= keys:
return {"kind": kind, "payload": payload}
return {"kind": "unknown", "payload": payload}
def answer_has_overlay(parsed: dict | None) -> bool:
"""True iff ``parsed`` carries drawable spatial coordinates."""
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def observation_image_to_pil(image_tensor: Any) -> Any:
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
The runtime observation stores images as ``(1, C, H, W)`` (or
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
``image_array_to_pil_image`` which handles the CHWHWC transpose and
the floatuint8 scaling.
"""
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
arr = image_tensor
if hasattr(arr, "detach"):
arr = arr.detach().cpu()
if hasattr(arr, "numpy"):
arr = arr.numpy()
while arr.ndim > 3: # drop leading batch dim(s)
arr = arr[0]
return image_array_to_pil_image(arr).convert("RGB")
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy. When ``parsed`` has
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
coordinates are scaled to the image's pixel size.
"""
from PIL import ImageDraw # noqa: PLC0415
img = image.convert("RGB").copy()
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
w, h = img.size
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
if kind == "bbox":
for det in payload.get("detections") or []:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
x1, x2 = x1 * sx, x2 * sx
y1, y2 = y1 * sy, y2 * sy
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
elif kind == "keypoint":
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]) * sx, float(point[1]) * sy
except (TypeError, ValueError):
return img
r = 6
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
label = str(payload.get("label", "")).strip()
if label:
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
return img
def _open_file(path: Path) -> None:
"""Best-effort open ``path`` in the OS default viewer."""
try:
if sys.platform == "darwin":
subprocess.run(["open", str(path)], check=False)
elif sys.platform.startswith("linux"):
subprocess.run(["xdg-open", str(path)], check=False)
elif os.name == "nt":
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
else: # pragma: no cover - exotic platform
webbrowser.open(path.resolve().as_uri())
except Exception as exc: # noqa: BLE001
logger.debug("could not auto-open %s: %s", path, exc)
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
path = out / f"vqa_{int(time.time() * 1000)}.png"
image.save(path)
_open_file(path)
return path
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
def handle_vqa_query(
*,
policy: Any,
observation_provider: Any,
question: str,
state: dict[str, Any],
input_fn: Any = input,
print_fn: Any = print,
) -> None:
"""Run one interactive VQA question end to end.
Called synchronously from the input layer while the runtime is in
``/question`` mode (the action loop is gated off, so the policy is
not in concurrent use). Progress is reported via both
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
stdout) in autonomous question mode the panel redraw is suspended,
so the direct print is what the operator actually sees.
"""
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
def report(line: str) -> None:
"""Surface a line both to the panel scrollback and to stdout."""
push_log(state, line)
try:
print_fn(line)
except Exception: # noqa: BLE001
pass
if policy is None or not hasattr(policy, "select_message"):
report(" [warn] vqa: policy has no select_message — skipping")
return
observation: dict | None = None
if observation_provider is not None:
try:
observation = observation_provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s", exc)
# Feed the FULL observation (every camera + state) to the VLM. The
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
# stripped before tokenization — the actual frames reach the model
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
# consumes *all* ``config.image_features`` regardless of which
# camera the sub-recipe was tagged for. So the model always sees
# every camera; the operator never has to name one to ask.
answer = _generate_with_policy(
policy,
_msgs_for_vqa(question),
observation=observation,
state=state,
label="vqa gen",
)
if not answer:
report(" [info] vqa gen returned empty")
return
report(f" vqa: {answer}")
parsed = parse_vqa_answer(answer)
if not answer_has_overlay(parsed):
if parsed is None:
report(" [info] vqa answer is not JSON — no overlay")
return
# The answer carries a bounding box / point. Its pixel coordinates
# are camera-specific and the text answer doesn't say which camera,
# so ask the operator *now* — only when there is actually something
# to draw — which camera frame to render the overlay on.
cameras = available_cameras(observation)
if observation is None or not cameras:
report(" [info] no camera image — cannot draw overlay")
return
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
if chosen is None:
report(" [info] overlay skipped — no camera selected")
return
try:
pil = observation_image_to_pil(observation[chosen])
overlay = draw_vqa_overlay(pil, parsed)
path = save_and_open_overlay(overlay)
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
except Exception as exc: # noqa: BLE001
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
File diff suppressed because it is too large Load Diff
@@ -1,198 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 pre/post-processor factory.
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
rename observations
add batch dim
relative-action prep (inherited from π0.5)
NormalizerProcessorStep
RenderMessagesStep recipe messages, target_message_indices,
message_streams (PR 1 of the steerable
stack)
PI052TextTokenizerStep messages input_ids + label mask +
predict_actions
DeviceProcessorStep
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
so unannotated datasets keep working.
Post-processor is unchanged from π0.5.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
# RenderMessagesStep is intentionally not re-exported from
# ``lerobot.processor`` because it pulls in optional language-stack deps;
# import it directly.
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
from .configuration_pi052 import PI052Config
from .text_processor_pi052 import PI052TextTokenizerStep
def make_pi052_pre_post_processors(
config: PI052Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_repo_id: str | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build PI0.5-v2's pre/post-processor pipelines.
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
"""
if not config.recipe_path:
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
recipe = _load_recipe(config.recipe_path)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
RenderMessagesStep(recipe=recipe),
PI052TextTokenizerStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
),
]
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
# Only inserted when explicitly enabled — keeps the post-training-
# style recipe (flow + text) as the default. When on, the step
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
if getattr(config, "enable_fast_action_loss", False):
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
# tokenizer on this dataset's action distribution rather than
# using the universal codebook off the shelf. We do this once
# and cache to disk, keyed on (dataset, base, n_samples).
action_tokenizer_path = config.action_tokenizer_name
if (
getattr(config, "auto_fit_fast_tokenizer", False)
and dataset_repo_id is not None
):
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
try:
action_tokenizer_path = fit_fast_tokenizer(
dataset_repo_id=dataset_repo_id,
cache_dir=cache_dir,
base_tokenizer_name=config.action_tokenizer_name,
n_samples=config.fast_tokenizer_fit_samples,
chunk_size=config.chunk_size,
)
except Exception as exc: # noqa: BLE001
import logging # noqa: PLC0415
logging.getLogger(__name__).warning(
"FAST tokenizer fit failed (%s) — falling back to "
"the universal base tokenizer %r. Train will still "
"work but compression will be suboptimal.",
exc, config.action_tokenizer_name,
)
input_steps.append(
ActionTokenizerProcessorStep(
action_tokenizer_name=action_tokenizer_path,
max_action_tokens=config.max_action_tokens,
fast_skip_tokens=config.fast_skip_tokens,
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
)
)
input_steps.append(DeviceProcessorStep(device=config.device))
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AbsoluteActionsProcessorStep(
enabled=config.use_relative_actions,
relative_step=relative_step,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
def _load_recipe(path_str: str) -> TrainingRecipe:
"""Resolve ``path_str`` to a ``TrainingRecipe``.
Accepts an absolute path or a path relative to
``src/lerobot/configs/``.
"""
p = Path(path_str)
if not p.is_absolute() and not p.exists():
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
configs_dir = Path(_recipe_module.__file__).resolve().parent
candidate = configs_dir / path_str
if candidate.exists():
p = candidate
return TrainingRecipe.from_yaml(p)
@@ -1,641 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 text-tokenisation step.
PaliGemma is *not* chat-pretrained, so we can't lean on
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
messages as plain text with simple ``User: ... Assistant: ...`` role
delimiters matching the prompt format π0.5 uses in the paper
(``Task: ... State: ... Action: ...``).
Outputs:
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` the
concatenated prompt tokenised by the PaliGemma tokenizer (the same
one ``processor_pi05`` already uses).
* ``text_labels`` same shape as token ids, ``-100`` everywhere except
positions belonging to messages whose index is in
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
those positions via the PaliGemma ``lm_head``.
* ``predict_actions`` bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
logger = logging.getLogger(__name__)
def discretize_state_str(state_row: Any) -> str:
"""Discretize a single normalized state vector into 256 bins, space-joined.
Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins /
convention) so pi052's low-level action prompt carries proprioception in
the exact format pi05 was trained on. Expects state already normalized by
the upstream ``NormalizerProcessorStep``.
"""
arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row)
disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
def _state_row_at(state_all: Any, pos: int) -> Any:
"""Select the per-sample state row from a (possibly batched) state tensor."""
if state_all is None:
return None
if hasattr(state_all, "ndim") and state_all.ndim >= 2:
return state_all[pos]
return state_all
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
b["text"]
for b in content
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
]
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
PaliGemma's flat text prompt has no notion of structured tool calls,
and ``_format_messages`` only reads ``role`` / ``content`` so
without this a ``say`` tool call is dropped entirely and never
supervised. Rewriting it into the content text as a ``<say>...</say>``
marker lets the LM head learn to emit it; the runtime parses it back
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
returned unchanged (the structured calls, if any, are still dropped).
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
new = dict(message)
new.pop("tool_calls", None)
if not say_texts:
return new
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
return new
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
The recipe renderer can emit ``content`` as a string OR as a list
of HF-style multimodal blocks (``{type: text, text: ...}``,
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
only consume strings, so we flatten: drop image blocks (cameras
flow through ``observation.images.*`` separately) and join text
block texts.
"""
new = dict(message)
new.pop("stream", None)
new.pop("target", None)
content = new.get("content")
if content is None:
new["content"] = ""
elif isinstance(content, str):
pass
elif isinstance(content, list):
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
t = block.get("text", "")
if isinstance(t, str):
parts.append(t)
new["content"] = "\n".join(parts)
else:
new["content"] = str(content)
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
# ---------------------------------------------------------------------------
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
#
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
# **01000 normalized coordinates**, NOT pixels. (Verified empirically
# on the published datasets: x and y both span 0..1000 with ~30% of
# values exceeding the camera's pixel dimensions — they're not pixels.)
# Converting to ``<loc>`` is therefore camera-resolution-independent:
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
# not in the dataset — so the dataset keeps the raw JSON and stays
# backbone-agnostic.
# ---------------------------------------------------------------------------
# The 01000 scale Qwen2.5-VL emits for grounding coordinates.
_VQA_COORD_SCALE = 1000.0
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
(detection / pointing) tokens, but the *stock* tokenizer does NOT
match them when encoding raw text it BPE-splits ``<loc0162>`` into
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
the LM head on a ``<loc>`` target then supervises those 7 generic
BPE pieces instead of one detection-vocab id, the LM head learns to
emit the *character sequence*, and those pieces' logits dominate
other turns (the ``<loc>``-salad on subtasks). Registering the loc
tokens once makes them tokenize as their single ids (256000+idx),
leveraging PaliGemma's detection prior properly. Idempotent.
"""
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
return tokenizer
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
return tokenizer
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
return f"<loc{max(0, min(1023, idx)):04d}>"
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
Input coordinates are in Qwen2.5-VL's 01000 normalized space (see
module-level note). y is emitted before x for each coordinate pair
(PaliGemma convention), with the integer indices in [0, 1023].
**Format: label first, locs after.** PaliGemma's pretraining puts
locs first (``<loc><loc> label``), but for our small-dataset VQA
blend that turns the LM head into a loc-emission attractor at every
``Assistant:`` position VQA targets share their first supervised
token with ~25% of all text samples, and the head collapses to
emitting ``<loc>`` regardless of the prompt. Putting the label
first (``label <locY><locX>``) means every text sample (subtask,
memory, VQA, ) starts the supervised target with a real word,
breaking the attractor. The model still learns the loc vocabulary
for the *spatial* portion of the answer; it just can't fire it as
the first generation step from a clean prompt.
Returns ``None`` for non-spatial answers (count / attribute /
spatial-relation) those keep their JSON form.
"""
point = answer.get("point")
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return None
label = str(answer.get("label", "")).strip()
if not label:
return None
return f"{label} {_loc_token(y)}{_loc_token(x)}"
detections = answer.get("detections")
if isinstance(detections, list) and detections:
parts: list[str] = []
for det in detections:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
label = str(det.get("label", "")).strip()
if not label:
continue
toks = (
f"{_loc_token(y1)}{_loc_token(x1)}"
f"{_loc_token(y2)}{_loc_token(x2)}"
)
parts.append(f"{label} {toks}")
return " ; ".join(parts) if parts else None
return None
def _messages_vqa_to_loc(
messages: list[dict[str, Any]],
target_indices: list[int],
) -> list[dict[str, Any]]:
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
Each target turn whose content parses as a spatial VQA answer is
converted. Non-spatial answers and subtask / memory targets (plain
text not JSON) are left untouched. Camera-independent: VQA coords
are 01000 normalized, so no observation lookup is needed.
"""
if not target_indices:
return messages
out = list(messages)
for idx in target_indices:
if not (0 <= idx < len(out)):
continue
content = out[idx].get("content")
if not isinstance(content, str) or not content.strip():
continue
try:
answer = json.loads(content)
except (ValueError, TypeError):
continue # subtask / memory targets are plain text — skip
if not isinstance(answer, dict):
continue
loc_text = _vqa_answer_to_loc(answer)
if loc_text is not None:
out[idx] = {**out[idx], "content": loc_text}
return out
def _format_messages(
messages: list[dict[str, Any]],
target_indices: list[int] | None = None,
eos_token: str | None = None,
) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt.
When both ``target_indices`` and ``eos_token`` are given, the EOS
string is appended to each supervised target turn's content and the
returned span covers it so the label builder marks the EOS token
as a supervised label. That teaches the LM head where the answer
*ends*: without an EOS in the target span the model is never given a
stop signal and rambles to ``max_length`` at inference. Inference
callers omit both args (no EOS baked into the prompt the model
generates it and ``select_message`` stops on it).
Returns:
prompt: the full text the tokenizer will consume.
msg_spans: list of ``(char_start, char_end)`` covering each
message's supervised payload (content, plus the
appended EOS for target turns) within ``prompt``.
"""
targets = set(target_indices or [])
parts: list[str] = []
spans: list[tuple[int, int]] = []
cursor = 0
for i, m in enumerate(messages):
role = m.get("role", "user")
content = m.get("content", "") or ""
# Role tag + newline. The model has to learn to emit the same
# role tokens at generation time, which is fine for greedy
# decoding because the chat template is implicit in the
# supervised target span.
header = f"{role.capitalize()}: "
# A supervised target turn ends with EOS so the model learns to
# terminate; the span below covers content + EOS. Non-target
# turns (and inference) carry no EOS.
body = content + eos_token if (eos_token and i in targets) else content
# span covers the content (+ EOS) portion only — never the role
# tag — so labels are computed over the supervised payload.
full = header + body + "\n"
start = cursor + len(header)
end = start + len(body)
parts.append(full)
spans.append((start, end))
cursor += len(full)
return "".join(parts), spans
@dataclass
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
class PI052TextTokenizerStep(ProcessorStep):
"""Render messages → token ids + label mask + predict_actions flag.
No chat template; concatenates messages as
``User: ... \\nAssistant: ...`` text.
"""
tokenizer_name: str = "google/paligemma-3b-pt-224"
max_length: int = 200
padding: str = "max_length"
padding_side: str = "right"
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
interjection_dropout_prob: float = 0.0
dropout_seed: int | None = None
def __post_init__(self) -> None:
self._tokenizer: Any = None
def _ensure_tokenizer(self) -> Any:
if self._tokenizer is not None:
return self._tokenizer
from transformers import AutoTokenizer # noqa: PLC0415
self._tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained(self.tokenizer_name)
)
return self._tokenizer
# ------------------------------------------------------------------
# Pipeline step
# ------------------------------------------------------------------
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
transition = transition.copy()
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
messages = complementary.get("messages") or []
if not messages:
# No recipe was rendered — caller will fall back to the
# plain Pi0.5 prompt path. We pass the transition through
# unmodified.
return transition
tokenizer = self._ensure_tokenizer()
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
# runs before this step). Injected into low-level action prompts so the
# action expert sees proprioception, matching pi05's discretized State:.
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
# VQA coords are 01000 normalized (Qwen2.5-VL convention) — the
# <loc> conversion is camera-resolution-independent and needs no
# observation lookup here.
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
state_row=_state_row_at(state_all, pos),
)
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
)
)
]
else:
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
encoded = [
self._encode_messages(
tokenizer,
messages,
list(complementary.get("message_streams") or []),
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
state_row=_state_row_at(state_all, 0),
)
]
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
}
return transition
def _encode_messages(
self,
tokenizer: Any,
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
state_row: Any = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
# after removal.
if (
self.plan_dropout_prob
or self.memory_dropout_prob
or self.subtask_dropout_prob
or self.interjection_dropout_prob
):
messages, target_indices = self._apply_prompt_dropout(
messages,
target_indices,
complementary,
sample_idx=sample_idx,
)
# Rewrite bbox / keypoint VQA target answers from JSON to
# PaliGemma <loc> text. Coords are 01000 normalized so this is
# camera-independent.
messages = _messages_vqa_to_loc(messages, target_indices)
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
# Low-level (action-conditioning) samples get the discretized state
# appended to their user message, mirroring pi05's
# "..., State: {256-bin};" so the action expert sees proprioception.
# Higher-level text streams (subtask/memory generation) stay state-free.
if state_row is not None and any(s == "low_level" for s in message_streams):
state_str = discretize_state_str(state_row)
for m in reversed(messages):
if m.get("role") == "user":
base = _content_to_text(m.get("content", ""))
m["content"] = f"{base}, State: {state_str};"
break
# Append EOS to supervised target turns so the LM head learns to
# stop (the span covers it → it becomes a supervised label).
prompt, spans = _format_messages(
messages, target_indices, getattr(tokenizer, "eos_token", None)
)
encoded = tokenizer(
prompt,
max_length=self.max_length,
padding=self.padding,
truncation=True,
return_tensors="pt",
return_offsets_mapping=True,
padding_side=self.padding_side,
)
input_ids = encoded["input_ids"][0]
attention_mask = encoded["attention_mask"][0].bool()
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
# Build label mask: -100 everywhere except over supervised
# target message char ranges.
labels = torch.full_like(input_ids, fill_value=-100)
for idx in target_indices:
if idx >= len(spans):
continue
char_start, char_end = spans[idx]
for token_pos in range(input_ids.shape[0]):
if not attention_mask[token_pos]:
continue
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
if tok_end <= char_start or tok_start >= char_end:
continue
labels[token_pos] = input_ids[token_pos]
# Scan ALL message streams (not just targets): the
# ``low_level_execution`` recipe drops ``target: true`` on
# the assistant to avoid trivial copy-from-user text-CE; the
# flow loss still needs to fire, gated by ``stream: low_level``.
predict_actions = torch.tensor(
bool(any(s == "low_level" for s in message_streams)),
dtype=torch.bool,
)
return input_ids, attention_mask, labels, predict_actions, prompt
# ------------------------------------------------------------------
# Per-component prompt dropout (Pi0.7 §V.E)
# ------------------------------------------------------------------
def _apply_prompt_dropout(
self,
messages: list[dict[str, Any]],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[list[dict[str, Any]], list[int]]:
"""Drop messages classified as plan/memory/subtask context.
Targets are *never* dropped (they're the supervised payload).
Re-maps target_indices to the new positions after drops.
"""
import random # noqa: PLC0415
seed = self.dropout_seed
if seed is None:
# Canonical row-index key set by ``BatchProcessor`` /
# ``render_messages_processor``. Falling back to other
# keys silently gave every sample seed=0 → identical
# dropout pattern across the whole epoch.
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
try:
if hasattr(seed_src, "item"):
seed_src = seed_src.item()
seed = int(seed_src)
except (TypeError, ValueError):
seed = 0
rng = random.Random(seed)
keep_indices: list[int] = []
for idx, msg in enumerate(messages):
if idx in target_indices:
keep_indices.append(idx)
continue
kind = _classify_for_dropout(msg)
prob = {
"plan": self.plan_dropout_prob,
"memory": self.memory_dropout_prob,
"subtask": self.subtask_dropout_prob,
"interjection": self.interjection_dropout_prob,
}.get(kind, 0.0)
if prob > 0.0 and rng.random() < prob:
continue
keep_indices.append(idx)
# Build remap and apply
new_messages = [messages[i] for i in keep_indices]
old_to_new = {old: new for new, old in enumerate(keep_indices)}
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
return new_messages, new_targets
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
content = message.get("content")
if isinstance(content, list):
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
content = " ".join(text_parts)
elif content is None:
return None
elif not isinstance(content, str):
return None
s = content.strip()
if s.startswith("Plan:") or s.startswith("Previous plan"):
return "plan"
if s.startswith("Memory:") or s.startswith("Previous memory"):
return "memory"
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
return "subtask"
return None
+2 -387
View File
@@ -14,28 +14,18 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING
import torch
from torch import Tensor, nn
from torch.nn import functional as F # noqa: N812
from torch import nn
from lerobot.utils.import_utils import _transformers_available
# Default PaliGemma SigLIP input resolution. Mirrors
# ``pi05.configuration_pi05.DEFAULT_IMAGE_SIZE``; duplicated as a plain constant
# to avoid importing the pi05 package here (which would create an import cycle:
# pi_gemma -> pi05.__init__ -> modeling_pi05 -> pi_gemma).
DEFAULT_IMAGE_SIZE = 224
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
@@ -59,8 +49,6 @@ else:
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
CONFIG_MAPPING = None
modeling_gemma = None
def _gated_residual(
@@ -287,8 +275,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
if causal_mask is not None and torch.is_floating_point(causal_mask):
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -381,374 +367,3 @@ __all__ = [
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
# PI0.5 / PI052 dual-expert backbone: generic PaliGemma + Gemma action-expert
# transformer machinery used by the pi052 policy. GemmaVariantConfig is openpi's
# width/depth variant config (renamed from GemmaConfig to avoid clashing with
# transformers' GemmaConfig).
def sdpa_attention_forward(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
):
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
``torch.nn.functional.scaled_dot_product_attention``.
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
bias masks (the FA backend only accepts causal/sliding-window). On
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
"""
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
if attention_mask is not None and attention_mask.dtype != query.dtype:
attention_mask = attention_mask.to(dtype=query.dtype)
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout if module.training else 0.0,
is_causal=False,
scale=scaling,
)
return attn_output.transpose(1, 2).contiguous(), None
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
att_output, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
class GemmaVariantConfig: # see openpi `gemma.py: Config`
"""Configuration for Gemma model variants."""
def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
self.width = width
self.depth = depth
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
def get_gemma_config(variant: str) -> GemmaVariantConfig: # see openpi `gemma.py: get_config`
"""Returns config for specified gemma variant."""
if variant == "gemma_300m":
return GemmaVariantConfig(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
elif variant == "gemma_2b":
return GemmaVariantConfig(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
else:
raise ValueError(f"Unknown variant: {variant}")
class PaliGemmaWithExpertModel(
nn.Module
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
"""PaliGemma model with action expert for PI05."""
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
for param in self.paligemma.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
# was trained in this regime, so scaling image features here over-scales them ~45x and
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values
-1
View File
@@ -1 +0,0 @@
../../../../docs/source/policy_vla_jepa_README.md
-23
View File
@@ -1,23 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]
@@ -1,337 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions
@@ -1,154 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -1,629 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model
@@ -1,155 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold -1, <= threshold 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -1,117 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
@@ -1,418 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)
+3
View File
@@ -175,6 +175,9 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions
@ProcessorStepRegistry.register("relative_actions_processor")
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
@@ -50,17 +50,7 @@ class RenderMessagesStep(ProcessorStep):
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return transition
new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
@@ -77,147 +67,18 @@ class RenderMessagesStep(ProcessorStep):
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return None
return None
new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def _call_batch(
self,
transition: EnvTransition,
complementary_data: dict[str, Any],
persistent_batch: list,
events_batch: list,
) -> EnvTransition | None:
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
batch_size = max(len(persistent_batch), len(events_batch))
messages: list[list[dict[str, Any]]] = []
message_streams: list[list[str | None]] = []
target_message_indices: list[list[int]] = []
keep_indices: list[int] = []
for i in range(batch_size):
rendered = render_sample(
recipe=self.recipe,
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
events=events_batch[i] if i < len(events_batch) else [],
t=_batch_value(timestamp, i),
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
task=_batch_value(complementary_data.get("task"), i),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
if rendered is None:
continue
keep_indices.append(i)
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
if not messages:
return None
new_transition = (
_select_batch_indices(transition, keep_indices)
if len(keep_indices) != batch_size
else transition.copy()
)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data["messages"] = messages
new_complementary_data["message_streams"] = message_streams
new_complementary_data["target_message_indices"] = target_message_indices
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0])
return value
def _is_batched_language(value: Any) -> bool:
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
def _batch_value(value: Any, index: int) -> Any:
if value is None:
return None
if isinstance(value, list):
return value[index]
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
return _scalar(value[index])
return _scalar(value)
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
selected = transition.copy()
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
data = selected.get(key)
if isinstance(data, dict):
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
action = selected.get(TransitionKey.ACTION)
if action is not None:
selected[TransitionKey.ACTION] = _select_value(action, indices)
return selected
def _select_value(value: Any, indices: list[int]) -> Any:
if isinstance(value, list) and len(value) >= len(indices):
return [value[i] for i in indices]
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long())
return value
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
"""Keep action-only samples trainable when no recipe branch matches."""
if hasattr(task, "item"):
task = task.item()
if isinstance(task, list):
messages = []
message_streams = []
target_message_indices = []
for t in task:
rendered = _fallback_low_level_render(t)
if rendered is None:
return None
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
return {
"messages": messages,
"message_streams": message_streams,
"target_message_indices": target_message_indices,
}
if not isinstance(task, str) or not task:
return None
return {
"messages": [{"role": "user", "content": task}],
"message_streams": ["low_level"],
"target_message_indices": [],
}
+16 -31
View File
@@ -32,7 +32,6 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -413,15 +412,14 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# During inference, no action is available, skip tokenization
return new_transition
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
tokens, mask, code_mask = self._tokenize_action(action)
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
@@ -432,7 +430,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
@@ -461,7 +459,6 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
code_masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
@@ -479,26 +476,19 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
bos_id = self._paligemma_tokenizer.bos_token_id
prompt_tokens = torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
)
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
code_start = 1 + len(prompt_tokens)
code_end = code_start + len(action_code_tokens)
# add bos
tokens = torch.cat(
[
torch.tensor([bos_id], device=action.device),
prompt_tokens,
action_code_tokens,
end_tokens,
torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
]
)
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
code_mask[code_start:code_end] = True
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
@@ -507,49 +497,44 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
)
tokens = tokens[: self.max_action_tokens]
code_mask = code_mask[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
pad_len = self.max_action_tokens - len(tokens)
mask = torch.cat(
[
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
torch.zeros(
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
]
)
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
# Pad tokens with zeros
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
tokens_list.append(tokens)
masks_list.append(mask)
code_masks_list.append(code_mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
code_masks_batch = code_masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
code_masks_batch = code_masks_batch.to(device)
return tokens_batch, masks_batch, code_masks_batch
return tokens_batch, masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _, _ = self._tokenize_action(action)
tokens, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]:
-2
View File
@@ -20,14 +20,12 @@ from .factory import (
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
"TOPRewardConfig",
# Base class
+2 -16
View File
@@ -25,7 +25,6 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig
@@ -39,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm", "robometer", "topreward".
"sarm", "topreward".
Returns:
The reward model class corresponding to the given name.
@@ -55,10 +54,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "robometer":
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
return RobometerRewardModel
elif name == "topreward":
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
@@ -79,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm", "robometer", "topreward".
"reward_classifier", "sarm", "topreward".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -92,8 +87,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
elif reward_type == "robometer":
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
else:
@@ -175,13 +168,6 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(reward_cfg, RobometerConfig):
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
return make_robometer_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, TOPRewardConfig):
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
-19
View File
@@ -1,19 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_robometer import RobometerConfig
from .modeling_robometer import RobometerRewardModel
from .processor_robometer import make_robometer_pre_post_processors
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
@@ -1,320 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
For each episode, builds per-frame sub-samples using the frame-steps
strategy from the Robometer eval server: for each original frame ``t``,
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
the Robometer processor + model, and keep the last-frame progress value.
All sub-samples are the same size ``K`` so they batch cleanly.
The parquet uses the same schema as SARM's
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
``progress_sparse``) and the progress-overlay script in
``examples/dataset/create_progress_videos.py`` work without modification.
Usage:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes with batching
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
from lerobot.types import TransitionKey
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
"""Read ``reward_model_path`` from parquet metadata if available."""
if not parquet_path.exists():
return None
try:
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
if metadata and b"reward_model_path" in metadata:
return metadata[b"reward_model_path"].decode()
except Exception: # nosec B110
return None
return None
def _resolve_task(sample: dict[str, Any], default: str) -> str:
"""Best-effort task extraction from a dataset sample."""
task = sample.get("task")
if isinstance(task, str) and task:
return task
return default
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
"""Frame-steps linspace expansion.
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
indices from ``np.linspace(0, t, num_subsampled_frames)`` the first
and last frames are always included. Each entry is a fixed-size array
so the model can batch them.
"""
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
def compute_robometer_progress(
dataset_repo_id: str,
reward_model_path: str,
output_path: str | None = None,
device: str = "cuda",
batch_size: int = 32,
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
episodes: list[int] | None = None,
image_key: str | None = None,
) -> Path:
"""Run Robometer over a dataset and write per-frame progress + success."""
logging.info(f"Loading Robometer: {reward_model_path}")
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
if image_key is not None:
config.image_key = image_key
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
model.to(device).eval()
encoder = RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=num_subsampled_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
)
image_key = config.image_key
logging.info(f"Loading dataset: {dataset_repo_id}")
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
logging.info(f"Processing {len(episode_indices)} episode(s)")
all_index: list[int] = []
all_episode: list[int] = []
all_frame: list[int] = []
all_progress: list[float] = []
for episode_idx in tqdm(episode_indices, desc="Episodes"):
ep = dataset.meta.episodes[episode_idx]
ep_start = int(ep["dataset_from_index"])
ep_end = int(ep["dataset_to_index"])
num_frames = ep_end - ep_start
if num_frames <= 0:
continue
first_sample = dataset[ep_start]
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
end = min(start + batch_size, num_frames)
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
transition = {
TransitionKey.OBSERVATION: {image_key: frames_batch},
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
}
encoded = encoder(transition)
obs = encoded[TransitionKey.OBSERVATION]
batch = {
key: value.to(device) if isinstance(value, torch.Tensor) else value
for key, value in obs.items()
}
with torch.no_grad():
rewards = model.compute_reward(batch)
progress_per_frame[start:end] = rewards.cpu().numpy()
for local in range(num_frames):
all_index.append(ep_start + local)
all_episode.append(episode_idx)
all_frame.append(local)
all_progress.append(float(progress_per_frame[local]))
if device.startswith("cuda"):
torch.cuda.empty_cache()
table = pa.table(
{
"index": np.asarray(all_index, dtype=np.int64),
"episode_index": np.asarray(all_episode, dtype=np.int64),
"frame_index": np.asarray(all_frame, dtype=np.int64),
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
}
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out)
logging.info(f"Saved {len(table)} frame values to {out}")
progress_arr = np.asarray(all_progress, dtype=np.float32)
if progress_arr.size:
logging.info(
f"Progress: mean={float(progress_arr.mean()):.4f}, "
f"std={float(progress_arr.std()):.4f}, "
f"min={float(progress_arr.min()):.4f}, "
f"max={float(progress_arr.max()):.4f}"
)
return out
def main():
parser = argparse.ArgumentParser(
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes, smaller batches for memory-constrained GPUs
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
""",
)
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
)
parser.add_argument(
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
)
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
parser.add_argument(
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
)
parser.add_argument(
"--num-subsampled-frames",
type=int,
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
)
parser.add_argument(
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
)
parser.add_argument(
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
)
parser.add_argument(
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
reward_model_path = args.reward_model_path
if reward_model_path is None:
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
if reward_model_path:
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
else:
raise ValueError(
"--reward-model-path is required (no existing parquet with model metadata found)."
)
output_path = compute_robometer_progress(
dataset_repo_id=args.dataset_repo_id,
reward_model_path=reward_model_path,
output_path=args.output_path,
device=args.device,
batch_size=args.batch_size,
num_subsampled_frames=args.num_subsampled_frames,
episodes=args.episodes,
image_key=args.image_key,
)
print(f"\nRobometer progress saved to: {output_path}")
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = DEFAULT_OUTPUT_FILENAME
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(output_path),
path_in_repo=hub_path,
repo_id=args.dataset_repo_id,
repo_type="dataset",
)
print(
"Successfully uploaded to: "
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
)
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
print(" rabc_head_mode: sparse")
else:
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: {output_path}")
print(" rabc_head_mode: sparse")
if __name__ == "__main__":
main()
@@ -1,158 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoTokenizer
else:
AutoConfig = None # type: ignore[assignment]
AutoTokenizer = None # type: ignore[assignment]
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
# The order is part of the data contract: upstream resized ``embed_tokens``
# after adding these tokens in this exact order, so changing the set or order
# would silently misalign the saved embedding rows with their token ids.
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
# heads (never read at inference) but still occupy rows the checkpoint expects.
ROBOMETER_SPECIAL_TOKENS = (
"<|split_token|>",
"<|reward_token|>",
"<|pref_token|>",
"<|sim_token|>",
"<|prog_token|>",
)
@RewardModelConfig.register_subclass("robometer")
@dataclass
class RobometerConfig(RewardModelConfig):
"""Configuration for the Robometer reward model."""
pretrained_path: str | None = "lerobot/Robometer-4B"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
reward_output: str = "progress" # "progress" or "success"
success_threshold: float = 0.5
license: str | None = "apache-2.0"
tags: list[str] | None = field(
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
)
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
torch_dtype: str = "bfloat16"
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
average_temporal_patches: bool = True
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
frame_pooling_attn_temperature: float = 1.0
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
progress_discrete_bins: int = 10
# Serialised Qwen backbone config (post-resize). Always populated by
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
# is non-empty after construction. Saved into ``config.json`` automatically
# by the base ``_save_pretrained``.
vlm_config: dict[str, Any] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self) -> None:
super().__post_init__()
if self.reward_output not in {"progress", "success"}:
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
if self.max_frames is not None and self.max_frames < 1:
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
if self.frame_pooling not in {"mean", "boundary", "attention"}:
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
if self.frame_pooling_attn_temperature <= 0:
raise ValueError("frame_pooling_attn_temperature must be > 0")
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
if self.use_per_frame_progress_token and not self.use_multi_image:
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
if self.image_key not in self.input_features:
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
# Deterministically populate ``vlm_config`` so it is non-empty after
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
# vocab the published ``Robometer-4B`` checkpoint was saved with.
if not self.vlm_config:
require_package("transformers", extra="robometer")
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
text_config = vlm.get("text_config")
if not isinstance(text_config, dict):
raise ValueError(
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
"Robometer expects a Qwen-VL-style config."
)
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
self.vlm_config = vlm
@property
def use_discrete_progress(self) -> bool:
"""Whether the progress head outputs distribution logits over bins."""
return self.progress_loss_type.lower() == "discrete"
@property
def vlm_backbone_config(self):
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
require_package("transformers", extra="robometer")
config_dict = deepcopy(self.vlm_config)
model_type = config_dict.pop("model_type", None)
if model_type is None:
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
return AutoConfig.for_model(model_type, **config_dict)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> None:
return None
@property
def reward_delta_indices(self) -> None:
return None
def validate_features(self) -> None:
if self.image_key not in self.input_features:
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
@@ -1,481 +0,0 @@
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
Paper: https://arxiv.org/abs/2603.02115
Project: https://robometer.github.io
Original code: https://github.com/aliang8/robometer
Model: https://huggingface.co/robometer/Robometer-4B
Robometer is a general-purpose, video-language-input reward model built on
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
objective:
- A frame-level progress loss anchoring reward magnitude on expert data.
- A trajectory-comparison preference loss imposing global ordering constraints
across trajectories sharing the same instruction.
To support downstream RL it also predicts a frame-level binary success. The
training prompt inserts three learnable tokens:
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
- ``<|split_token|>`` between two trajectories in preference samples
(training-only).
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
is recovered as the softmax-weighted mean of those centers see
:func:`convert_bins_to_continuous`.
This LeRobot port is **inference-only**: the preference head is preserved in
the state dict for byte-equivalence with the published ``Robometer-4B``
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
success probability depending on :attr:`RobometerConfig.reward_output`.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import torch
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModelForImageTextToText
else:
AutoModelForImageTextToText = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
ROBOMETER_QWEN_INPUT_KEYS = (
"input_ids",
"attention_mask",
"pixel_values",
"pixel_values_videos",
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"mm_token_type_ids",
)
ROBOMETER_METADATA_KEYS = (
"prog_token_id",
"vision_start_token_id",
"vision_end_token_id",
"video_merge_size",
)
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
"""Collapse per-bin logits into a single value in ``[0, 1]``.
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
softmax-weighted mean of those centers.
"""
bin_probs = torch.softmax(bin_logits, dim=-1)
num_bins = bin_logits.shape[-1]
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
return (bin_probs * bin_centers).sum(dim=-1)
def _squeeze_last_safe(x: Tensor) -> Tensor:
"""Drop a trailing singleton dim only when present."""
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
def _torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if isinstance(dtype, torch.dtype):
return dtype
raise ValueError(f"Unknown torch dtype: {name!r}")
class RobometerPredictionHead(nn.Sequential):
"""Small MLP head used for Robometer's progress / success / preference outputs."""
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
layers: list[nn.Module] = [
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_size),
]
if with_sigmoid:
layers.append(nn.Sigmoid())
super().__init__(*layers)
def decode_progress_outputs(
progress_logits: Tensor | None,
success_logits: Tensor | None,
*,
is_discrete_mode: bool,
) -> dict[str, list[list[float]]]:
"""Decode RBM head outputs into per-frame floats.
Args:
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
is_discrete_mode: if True the progress logits get a softmax over bins
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
Returns:
Dict with ``progress_pred`` and ``success_probs``, each a list of
length ``B`` of per-frame float lists.
"""
progress_pred: list[list[float]] = []
success_probs: list[list[float]] = []
if progress_logits is not None:
for sample_logits in progress_logits:
if is_discrete_mode:
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
progress_pred.append(continuous.flatten().tolist())
else:
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
if success_logits is not None:
for sample_logits in success_logits:
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
return {"progress_pred": progress_pred, "success_probs": success_probs}
class RobometerRewardModel(PreTrainedRewardModel):
"""Robometer (RBM) reward model — inference-only LeRobot port.
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
prediction heads from the paper (progress, success, preference). At
inference time only the progress and success heads are queried; the
preference head is kept on the module so the published ``Robometer-4B``
safetensors load unchanged.
"""
name = "robometer"
config_class = RobometerConfig
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
require_package("transformers", extra="robometer")
super().__init__(config)
self.config = config
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
#
# - Fresh training (``pretrained_path is None``): download the base
# Qwen weights and resize the embed table to match
# ``vlm_config.text_config.vocab_size`` — populated deterministically
# in ``RobometerConfig.__post_init__`` as
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
#
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
# the empty architecture from ``vlm_config`` via
# ``AutoModelForImageTextToText.from_config`` so the subsequent
# ``model.safetensors`` load is a direct fill of the right shape —
# no redundant Qwen weight download.
torch_dtype = _torch_dtype(config.torch_dtype)
if config.pretrained_path is None:
self.model = AutoModelForImageTextToText.from_pretrained(
config.base_model_id,
dtype=torch_dtype,
trust_remote_code=True,
)
target_vocab = config.vlm_config["text_config"]["vocab_size"]
self.model.resize_token_embeddings(target_vocab)
else:
self.model = AutoModelForImageTextToText.from_config(
config.vlm_backbone_config,
dtype=torch_dtype,
trust_remote_code=True,
)
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
# Falls back to the top-level `hidden_size` so future non-multimodal
# variants would still resolve.
backbone_config = self.model.config
text_config = getattr(backbone_config, "text_config", None)
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
if hidden_size is None:
hidden_size = getattr(backbone_config, "hidden_size", None)
if hidden_size is None:
raise AttributeError(
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
)
hidden_dim = int(hidden_size)
# Robometer's three prediction heads + frame-pool attention.
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
self.progress_head = RobometerPredictionHead(
hidden_dim,
progress_output,
dropout=dropout,
with_sigmoid=not config.use_discrete_progress,
)
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
# Match the dtype of the loaded base model so weight loading is a no-op cast.
model_dtype = next(self.model.parameters()).dtype
self.progress_head.to(dtype=model_dtype)
self.preference_head.to(dtype=model_dtype)
self.success_head.to(dtype=model_dtype)
self.frame_pool_attn.to(dtype=model_dtype)
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
inputs = {
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
for key in ROBOMETER_INPUT_KEYS
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
}
if "input_ids" not in inputs:
raise KeyError(
f"Robometer batch missing pre-encoded inputs (expected "
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
"RobometerEncoderProcessorStep ran before `compute_reward`."
)
device = next(self.model.parameters()).device
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
self.eval()
with torch.no_grad():
progress_logits, success_logits = self._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=self.config.use_discrete_progress,
)
values = (
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
)
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
if self.config.reward_output == "success":
rewards = (rewards > self.config.success_threshold).float()
else:
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
# progress predictions are clamped to ``[0, 1]`` before being returned.
rewards = rewards.clamp(0.0, 1.0)
return rewards.to(self.config.device or "cpu")
def _compute_rbm_logits(
self,
inputs: dict[str, Any],
) -> tuple[Tensor, Tensor]:
"""Run the Qwen3-VL backbone and apply Robometer's heads.
``inputs`` is the encoded batch produced by
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
as Robometer-specific metadata (``prog_token_id``,
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
the metadata is popped here so the rest can be forwarded straight to
the Qwen model.
Returns ``(progress_logits, success_logits)``. Shapes:
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
"""
prog_token_id = inputs.pop("prog_token_id", None)
vision_start_token_id = inputs.pop("vision_start_token_id", None)
vision_end_token_id = inputs.pop("vision_end_token_id", None)
video_merge_size = inputs.pop("video_merge_size", 14)
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
# full hidden-state tuple and take the last layer. This matches the
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
hidden_state = (
outputs.hidden_states[-1]
if getattr(outputs, "hidden_states", None)
else outputs.last_hidden_state
)
input_ids = inputs["input_ids"]
if self.config.use_per_frame_progress_token:
if prog_token_id is None:
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
if self.config.use_multi_image:
if vision_start_token_id is None or vision_end_token_id is None:
raise KeyError(
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
"(run RobometerEncoderProcessorStep first)"
)
return self._process_multi_image_frames(
hidden_state,
input_ids,
start_id=vision_start_token_id,
end_id=vision_end_token_id,
)
video_grid_thw = inputs.get("video_grid_thw")
if video_grid_thw is None:
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
if vision_start_token_id is None:
raise KeyError("`vision_start_token_id` missing in batch")
return self._process_video_frames(
hidden_state,
input_ids,
video_grid_thw,
start_id=vision_start_token_id,
merge_size=video_merge_size,
)
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
"""Apply progress + success heads to a tensor of frame embeddings."""
progress_out = self.progress_head(frame_embeddings)
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
success = _squeeze_last_safe(self.success_head(frame_embeddings))
return progress, success
def _process_token_extraction(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
prog_token_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
token_mask = input_ids == prog_token_id
batch_indices, positions = token_mask.nonzero(as_tuple=True)
if positions.numel() == 0:
raise ValueError("`<|prog_token|>` not found in any sequence")
per_sample_hidden = [
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
]
progress_list, success_list = [], []
for embeddings in per_sample_hidden:
if embeddings.shape[0] == 0:
raise ValueError("`<|prog_token|>` missing in a sequence")
progress, success = self._apply_heads_to_hidden_states(embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _process_multi_image_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
start_id: int,
end_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
frame_embeddings = self._extract_hidden_states_from_token_pairs(
seq_hidden, seq_ids, start_id, end_id
)
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _extract_hidden_states_from_token_pairs(
self,
hidden_state: Tensor,
input_ids: Tensor,
start_id: int,
end_id: int,
) -> Tensor:
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
if start_positions.numel() != end_positions.numel():
raise ValueError(
f"Mismatched vision token counts: {start_positions.numel()} start vs "
f"{end_positions.numel()} end"
)
frames: list[Tensor] = []
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
if start >= end:
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
patch_tokens = hidden_state[start + 1 : end]
if patch_tokens.shape[0] == 0:
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
continue
pooling = self.config.frame_pooling
if pooling == "mean":
frames.append(patch_tokens.mean(dim=0))
elif pooling == "boundary":
frames.append(patch_tokens[-1])
else: # attention
scores = (
self.frame_pool_attn(patch_tokens).squeeze(-1)
/ self.config.frame_pooling_attn_temperature
)
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
frames.append((weights * patch_tokens).sum(dim=0))
return torch.stack(frames)
def _process_video_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
video_grid_thw: Tensor,
*,
start_id: int,
merge_size: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in video mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
cursor = start_positions[0].item()
frame_embeddings: list[Tensor] = []
for _ in range(t_dim):
if self.config.average_temporal_patches:
patch = seq_hidden[cursor : cursor + tokens_per_frame]
frame_embeddings.append(patch.mean(dim=0))
else:
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
cursor += tokens_per_frame
stacked = torch.stack(frame_embeddings)
progress, success = self._apply_heads_to_hidden_states(stacked)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
@@ -1,338 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Robometer pre/post processing pipelines."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
policy_action_to_transition,
)
from lerobot.rewards.robometer.configuration_robometer import (
ROBOMETER_SPECIAL_TOKENS,
RobometerConfig,
)
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor
else:
AutoProcessor = None
PROGRESS_PROMPT = (
"The task for the robot is '{task}'. Given the trajectory video, predict "
"the task progress at each frame, how far along the robot is towards "
"completing the task, a float between 0 and 1, where 0 is the starting "
"state and 1 is when the task is completed. If the robot is not "
"performing the same task, predict 0 progress."
)
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
if frames.ndim != 4:
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
if frames.dtype != np.uint8:
frames = np.clip(frames, 0, 255).astype(np.uint8)
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
if max_frames is not None:
video = video[-max_frames:]
if video.shape[1] in (1, 3):
video = video.permute(0, 2, 3, 1)
elif video.shape[-1] not in (1, 3):
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
array = video.detach().cpu().numpy()
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
array = array * 255.0
return np.clip(array, 0, 255).astype(np.uint8)
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
if task is None:
task = default
if task is None:
raise KeyError("Robometer expected a task description in complementary data")
if isinstance(task, str):
return [task] * batch_size
if isinstance(task, tuple):
task = list(task)
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
if len(task) == 1 and batch_size > 1:
return task * batch_size
if len(task) != batch_size:
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
return task
@dataclass
@ProcessorStepRegistry.register(name="robometer_encoder")
class RobometerEncoderProcessorStep(ProcessorStep):
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
registers Robometer's special tokens on the tokenizer. The matching
embedding resize happens model-side in
:meth:`RobometerRewardModel.__init__`.
At call time the step reads:
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
- ``complementary_data[task_key]``: a string or list of strings.
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
- Robometer-specific token ids consumed by the model heads:
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
``video_merge_size``.
"""
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
max_length: int = 1024
_processor: Any = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
require_package("transformers", extra="robometer")
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
self._processor = AutoProcessor.from_pretrained(
self.base_model_id,
trust_remote_code=True,
do_sample_frames=False,
padding_side="right",
)
# Register Robometer's special tokens on the tokenizer. The matching
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
tokenizer = self._processor.tokenizer
# Qwen tokenizers may not define a pad token, but batched prompts/videos
# require padding, so reuse EOS as the padding token.
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
for token in ROBOMETER_SPECIAL_TOKENS:
if token not in tokenizer.get_vocab():
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if not isinstance(observation, dict):
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
if self.image_key not in observation:
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
frames = observation[self.image_key]
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
if tensor.ndim == 4:
tensor = tensor.unsqueeze(1)
elif tensor.ndim != 5:
raise ValueError(
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
)
batch_size = tensor.shape[0]
tasks = _expand_tasks(
complementary.get(self.task_key, self.default_task),
batch_size=batch_size,
default=self.default_task,
)
samples = [
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
]
encoded = self.encode_samples(samples)
new_observation = dict(observation)
for key, value in encoded.items():
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
from qwen_vl_utils import process_vision_info
conversations = [self._build_conversation(frames, task) for frames, task in samples]
texts = [
self._processor.apply_chat_template(
msg,
tokenize=False,
add_generation_prompt=False,
add_vision_id=True,
enable_thinking=False,
fps=1,
)
for msg in conversations
]
process_kwargs: dict[str, Any] = {
"return_video_kwargs": True,
"return_video_metadata": True,
}
image_processor = getattr(self._processor, "image_processor", None)
if image_processor is not None and hasattr(image_processor, "patch_size"):
process_kwargs["image_patch_size"] = image_processor.patch_size
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
videos: list[Any] | None = None
video_metadatas: list[Any] | None = None
if video_inputs:
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
videos = list(videos_seq)
video_metadatas = list(metadatas_seq)
else:
videos = list(video_inputs)
processor_kwargs: dict[str, Any] = {
"text": texts,
"images": image_inputs,
"padding": True,
"truncation": False,
"max_length": self.max_length,
"return_tensors": "pt",
"do_resize": False,
}
if videos is not None:
processor_kwargs["videos"] = videos
if video_metadatas is not None:
processor_kwargs["video_metadata"] = video_metadatas
if video_kwargs:
processor_kwargs.update(video_kwargs)
encoded = self._processor(**processor_kwargs)
# Write Robometer-specific token ids and the video patch merge size into
# the encoded batch so `RobometerRewardModel` doesn't need its own
# tokenizer at inference (EO1-style separation: the processor owns the
# tokenizer, the model owns the backbone and heads).
tokenizer = self._processor.tokenizer
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
video_processor = getattr(self._processor, "video_processor", None)
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
return encoded
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
pil_frames = _frames_to_pil(frames)
prompt = PROGRESS_PROMPT.format(task=task)
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
if self.use_multi_image:
for image in pil_frames:
content.append({"type": "image", "image": image})
if self.use_per_frame_progress_token:
content.append({"type": "text", "text": "<|prog_token|>"})
else:
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
return [{"role": "user", "content": content}]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"base_model_id": self.base_model_id,
"image_key": self.image_key,
"task_key": self.task_key,
"default_task": self.default_task,
"max_frames": self.max_frames,
"use_multi_image": self.use_multi_image,
"use_per_frame_progress_token": self.use_per_frame_progress_token,
"max_length": self.max_length,
}
def make_robometer_pre_post_processors(
config: RobometerConfig,
dataset_stats: dict[str, dict[str, Any]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
The preprocessor adds a batch dimension if needed, runs Robometer's
encoder, and moves everything to the configured device. The
postprocessor is the identity since Robometer outputs a single reward
tensor.
"""
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=config.max_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
+1 -3
View File
@@ -21,8 +21,6 @@ from lerobot.utils.import_utils import make_device_from_device_class
from .config import RobotConfig
from .robot import Robot
logger = logging.getLogger(__name__)
def make_robot_from_config(config: RobotConfig) -> Robot:
# TODO(Steven): Consider just using the make_device_from_device_class for all types
@@ -120,7 +118,7 @@ def ensure_safe_goal_position(
}
if warnings_dict:
logger.warning(
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f"{pformat(warnings_dict, indent=4)}"
)
-4
View File
@@ -23,7 +23,6 @@ from .configs import (
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
@@ -50,7 +49,6 @@ from .inference import (
from .strategies import (
BaseStrategy,
DAggerStrategy,
EpisodicStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
@@ -68,8 +66,6 @@ __all__ = [
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"EpisodicStrategy",
"EpisodicStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
+1 -36
View File
@@ -121,35 +121,6 @@ class DAggerPedalConfig:
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("episodic")
@dataclass
class EpisodicStrategyConfig(RolloutStrategyConfig):
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
Keyboard controls:
Right arrow end current episode or reset phase early
Left arrow discard current episode and re-record
Escape stop recording session
In between episodes:
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
- else, the robot is moved smoothly to the position of the teleop leader.
"""
# This only applies if there are no teleop leaders specified.
# When True (default), moves the robot back to the joint positions captured at startup.
# Otherwise, leave the robot in its current position.
reset_to_initial_position: bool = True
# Whether to turn on or off the leader -> follower smooth handover behavior.
# When False, fallback to follower -> leader handover.
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
smooth_leader_to_follower_handover: bool = True
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
@@ -258,13 +229,7 @@ class RolloutConfig:
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy,
(
SentryStrategyConfig,
HighlightStrategyConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
),
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
@@ -17,7 +17,6 @@
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .episodic import EpisodicStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -28,7 +27,6 @@ __all__ = [
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"EpisodicStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
+69 -14
View File
@@ -56,14 +56,10 @@ from typing import Any
import numpy as np
from lerobot.common.control_utils import (
follower_smooth_move_to,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
@@ -73,6 +69,7 @@ from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
@@ -174,6 +171,64 @@ class DAggerEvents:
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def _follower_smooth_move_to(
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm
to the follower, we bring the follower to the teleop's current pose.
Both ``current`` and ``target`` must be in robot-action key space.
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
@@ -701,31 +756,31 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Pausing engine - robot holds position")
engine.pause()
if teleop_supports_feedback(teleop) and prev_action is not None:
if _teleop_supports_feedback(teleop) and prev_action is not None:
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
logger.info("Smooth handover: moving leader arm to follower position")
teleop_smooth_move_to(teleop, prev_action)
_teleop_smooth_move_to(teleop, prev_action)
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode - human teleop control")
if not teleop_supports_feedback(teleop) and prev_action is not None:
if not _teleop_supports_feedback(teleop) and prev_action is not None:
logger.info("Smooth handover: sliding follower to teleop position")
obs = robot.get_observation()
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
follower_smooth_move_to(robot, prev_action, target)
_follower_smooth_move_to(robot, prev_action, target)
# unlock the teleop for human control
if teleop_supports_feedback(teleop):
if _teleop_supports_feedback(teleop):
teleop.disable_torque()
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
if teleop_supports_feedback(teleop):
if _teleop_supports_feedback(teleop):
teleop.enable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
@@ -735,7 +790,7 @@ class DAggerStrategy(RolloutStrategy):
engine.resume()
# release teleop before resuming the policy
if teleop_supports_feedback(teleop):
if _teleop_supports_feedback(teleop):
teleop.disable_torque()
# ------------------------------------------------------------------

Some files were not shown because too many files have changed in this diff Show More