mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c78023dae7 | |||
| 36d0ba5127 | |||
| dca792951e | |||
| 0a369e104a | |||
| b0cdf99957 | |||
| 733f9768b5 | |||
| 7fe49f9e54 | |||
| e1afb96474 | |||
| f395f36dec | |||
| 738ba9272f | |||
| 2a0495f8c3 | |||
| c3c9c2b089 | |||
| e13c6a6110 | |||
| 140cf2a420 | |||
| c092194cf2 | |||
| b858ba1b6c | |||
| e870af119f | |||
| 4174c3b303 |
@@ -65,9 +65,6 @@ repos:
|
||||
name: Format Markdown with Prettier
|
||||
types_or: [markdown, mdx]
|
||||
args: [--prose-wrap=preserve]
|
||||
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
|
||||
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
|
||||
exclude: ^src/lerobot/templates/.*\.md$
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
|
||||
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
||||
|
||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||
|
||||
@@ -101,13 +101,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
@@ -135,7 +133,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -143,7 +140,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
@@ -45,8 +43,6 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
@@ -65,8 +61,6 @@
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
@@ -81,8 +75,6 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: robometer
|
||||
title: ROBOMETER
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` watches each episode's video with a vision-language
|
||||
model (VLM) and writes natural-language annotations back into your
|
||||
dataset. It fills the two language columns from the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — straight into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||
memory, interjections, speech, and visual Q&A that a policy can be
|
||||
trained on.
|
||||
|
||||
## How it fits together
|
||||
|
||||
```text
|
||||
your dataset lerobot-annotate
|
||||
(LeRobot v3.1)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ read episodes │
|
||||
└──────────────────────────┬──────────────────────────┘
|
||||
│
|
||||
┌────────────────────┼────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||
└────────────────────┼─────────────────────┘
|
||||
│ each module stages raw JSONL
|
||||
▼ into .annotate_staging/
|
||||
┌─────────────────┐
|
||||
│ validator │ ◀── checks everything
|
||||
└────────┬────────┘
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ writer │
|
||||
└────────┬────────┘
|
||||
▼
|
||||
data/chunk-*/file-*.parquet
|
||||
(+ meta/info.json tools)
|
||||
```
|
||||
|
||||
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||
single writer rewrites the dataset shards in place.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||
the two language columns:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | --------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections` |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
### How subtasks are generated
|
||||
|
||||
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||
it uses a two-step **describe → segment** flow:
|
||||
|
||||
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||
chosen camera (no guessing about the task).
|
||||
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||
episode into consecutive atomic subtasks.
|
||||
|
||||
Both passes see the episode as **timestamped contact sheets** — frames
|
||||
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
|
||||
grids with each frame's time burned into its corner, so the VLM cites
|
||||
exact boundary times directly. This is far cheaper in vision tokens than
|
||||
one image per frame, so the sampling can stay dense; episodes longer than
|
||||
`max_frames_per_prompt` are split into windows at the same density and
|
||||
merged. Both prompts also carry a causal **event-boundary** definition (a
|
||||
new event starts when an object becomes held / is released / reaches a new
|
||||
location / a lid changes state / contents move) to sharpen where cuts land.
|
||||
|
||||
The resulting spans are then stitched into a gap-free, full-episode
|
||||
cover, so **every frame has exactly one active subtask**. See
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
for the production settings (single camera, timestamped contact sheets,
|
||||
auto-windowed subtask generation).
|
||||
|
||||
### Tools
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet. The tool
|
||||
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||
that list, keeping any tools you declared beforehand.
|
||||
|
||||
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||
pipeline preserves whatever is already there. That makes the tool visible
|
||||
to the chat template, so the model can learn to _generate_ the call. The
|
||||
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||
this PR — the [Tools](./tools) doc marks those pieces as
|
||||
not-yet-implemented.
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||
The repo ships a launcher script you copy and tweak for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||
that:
|
||||
|
||||
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||
drives it over the OpenAI-compatible API,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
with `lerobot-annotate`,
|
||||
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||
back to `--repo_id` in place if you leave that unset).
|
||||
|
||||
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||
(run `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Key options
|
||||
|
||||
These are the flags you'll reach for most often. Run
|
||||
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||
short manipulation episodes.
|
||||
|
||||
### Dataset in / out
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------- | ------- | ----------------------------------------------------------------------- |
|
||||
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
|
||||
| `--root` | — | Annotate a local dataset directory instead. |
|
||||
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
|
||||
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
|
||||
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
|
||||
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||
|
||||
### Which modules run
|
||||
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
|
||||
| ------------------------- | ------- | ----------------------------------- |
|
||||
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
|
||||
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
|
||||
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
|
||||
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
|
||||
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
|
||||
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
|
||||
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
|
||||
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
|
||||
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||
|
||||
### Subtasks / plan / memory (`--plan.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
|
||||
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
|
||||
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
|
||||
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
|
||||
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
|
||||
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
|
||||
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
|
||||
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
|
||||
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||
|
||||
### Interjections + VQA
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
|
||||
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
|
||||
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
|
||||
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
|
||||
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||
|
||||
## Contributing new modules
|
||||
|
||||
The pipeline is built to grow, and **contributions are very welcome** —
|
||||
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||
template, a smarter grounding flow, or quality fixes to the existing
|
||||
`plan` / `interjections` / `vqa` modules.
|
||||
|
||||
Every module lives under
|
||||
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||
client and the keyframe cache, writes its raw output to the staging
|
||||
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||
|
||||
## How recipes consume the output
|
||||
|
||||
The annotations are meant to be read by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||
|
||||
- low-level / high-level / memory-update branches read
|
||||
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||
- an interjection-response branch reads `interjection` events plus the
|
||||
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||
and the matching `plan` refresh at the same timestamp.
|
||||
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||
`language_events`.
|
||||
|
||||
## Why state and events are split
|
||||
|
||||
Two ideas shape the design:
|
||||
|
||||
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||
the one frame whose timestamp matches. Timestamps are copied straight
|
||||
from the source parquet — never recomputed in floating point.
|
||||
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||
for one model load per dataset, not three.
|
||||
|
||||
## Re-running a single module
|
||||
|
||||
Each module stages its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||
prompt iteration cheap: re-running one module overwrites only its own
|
||||
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||
don't want with `--plan.enabled=false` (and likewise
|
||||
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
|
||||
|
||||
## What the validator checks
|
||||
|
||||
Before the writer runs, `StagingValidator` confirms:
|
||||
|
||||
- every event row lands exactly on a real frame timestamp;
|
||||
- no speech / interjection pairs are left orphaned;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||
- each VQA assistant `content` is valid JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row goes to the column chosen by `column_for_style(style)`.
|
||||
|
||||
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||
while debugging.
|
||||
|
||||
## Where each module's ideas come from
|
||||
|
||||
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||
for "how, not what" detail.
|
||||
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||
keep only the minimal relevant information — preserve outcomes, drop
|
||||
specific attributes.
|
||||
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom
|
||||
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies
|
||||
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||
|
||||
When improving a module, tweak its prompt template in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||
rewriting from scratch.
|
||||
|
||||
## Roughly how much it costs
|
||||
|
||||
Per episode, the pipeline makes about `max_steps` plan calls,
|
||||
`max_interjections_per_episode` interjection calls, and
|
||||
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
|
||||
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
|
||||
~50 VLM calls.
|
||||
|
||||
Storage stays small: `language_persistent` is at most tens of KB per
|
||||
episode (parquet dictionary-encodes the one entry that repeats across
|
||||
frames), and `language_events` is empty on most frames — its size scales
|
||||
with the number of emissions, not `num_frames × num_emissions`.
|
||||
@@ -157,14 +157,6 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Working with depth
|
||||
|
||||
The Intel RealSense and Reachy 2 cameras can capture both color and depth in lockstep. Calling `read()` returns the **color** frame as `(H, W, 3)` `uint8`. Calling `read_depth()` returns the **depth map** as `(H, W, 1)` `uint16`, where each pixel value is the distance from the sensor expressed in **millimetres**. A pixel value of `0` typically means "no measurement available" (out-of-range, occluded, or low-confidence).
|
||||
|
||||
During recording, the control loop peeks the freshest buffered frames non-blockingly via `read_latest()` (color) and `read_latest_depth()` (depth), adding the depth map as a sibling feature (e.g. `front_depth` next to `front`).
|
||||
|
||||
For how depth streams are stored and encoded when recording a dataset, see the [Depth streams](./video_encoding_parameters#depth-streams) section of the video encoding guide.
|
||||
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
|
||||
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
|
||||
@@ -647,6 +647,5 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.type=openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
@@ -157,44 +157,6 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# LeLab - LeRobot Guide
|
||||
|
||||
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||
|
||||
> [!TIP]
|
||||
> For now LeLab is compatible only with SO-ARM101
|
||||
|
||||
<Youtube id="VqyKUuW9V1g" />
|
||||
|
||||
### Installation
|
||||
|
||||
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
After install, run `lelab` from your terminal anytime to start the app.
|
||||
|
||||
### Features
|
||||
|
||||
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||
@@ -275,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
|
||||
@@ -238,7 +238,7 @@ your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ lerobot-train \
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
||||
If you have existing datasets in v2.1 format, use the migration tool:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id your_id/existing_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
# ROBOMETER
|
||||
|
||||
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
|
||||
|
||||
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
|
||||
**Project**: [robometer.github.io](https://robometer.github.io/)
|
||||
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
|
||||
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
|
||||
|
||||
## Overview
|
||||
|
||||
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
|
||||
|
||||
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
|
||||
- **Success head**: predicts per-frame task success probability.
|
||||
- **Preference head**: predicts which of two trajectories better completes the task during training.
|
||||
|
||||
The paper trains ROBOMETER with a composite objective:
|
||||
|
||||
```text
|
||||
L = L_pref + L_prog + L_succ
|
||||
```
|
||||
|
||||
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
|
||||
|
||||
## What the LeRobot Integration Covers
|
||||
|
||||
- Standard `reward_model.type=robometer` configuration through LeRobot.
|
||||
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
|
||||
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
|
||||
- Dense, frame-level progress and success predictions internally.
|
||||
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
|
||||
|
||||
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the ROBOMETER dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[robometer]"
|
||||
```
|
||||
|
||||
If you use `uv` directly from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra robometer
|
||||
```
|
||||
|
||||
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
ROBOMETER expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets, the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | ------------------------ | ----------------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
|
||||
|
||||
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
|
||||
|
||||
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
|
||||
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the Reward Model Directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||
|
||||
cfg = RobometerConfig(
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
reward_output="progress",
|
||||
)
|
||||
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
|
||||
```
|
||||
|
||||
### Encode Frames and Compute a Reward
|
||||
|
||||
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
|
||||
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||
# task: str
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=cfg.base_model_id,
|
||||
use_multi_image=cfg.use_multi_image,
|
||||
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
|
||||
max_frames=cfg.max_frames,
|
||||
)
|
||||
|
||||
encoded = encoder.encode_samples([(frames, task)])
|
||||
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
|
||||
|
||||
reward = reward_model.compute_reward(batch)
|
||||
```
|
||||
|
||||
`reward` is a tensor of shape `(batch_size,)`.
|
||||
|
||||
### Use the Reward Factory
|
||||
|
||||
You can also instantiate ROBOMETER through the reward factory:
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"robometer",
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Backbone and Vocabulary
|
||||
|
||||
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
|
||||
|
||||
```text
|
||||
<|split_token|>
|
||||
<|reward_token|>
|
||||
<|pref_token|>
|
||||
<|sim_token|>
|
||||
<|prog_token|>
|
||||
```
|
||||
|
||||
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
|
||||
|
||||
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
|
||||
|
||||
### Progress Prediction
|
||||
|
||||
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
|
||||
|
||||
### Success Prediction
|
||||
|
||||
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
|
||||
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
|
||||
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
|
||||
|
||||
## References
|
||||
|
||||
- [ROBOMETER project](https://robometer.github.io/)
|
||||
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
|
||||
- [Original ROBOMETER code](https://github.com/robometer/robometer)
|
||||
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
|
||||
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{liang2026robometer,
|
||||
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
|
||||
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
|
||||
year={2026},
|
||||
booktitle={Robotics: Science and Systems 2026},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
|
||||
@@ -11,9 +11,8 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage (RGB and depth cameras are encoded with separate encoders)
|
||||
7. **Re-encode Videos** - Re-encode an existing video dataset's RGB and/or depth streams with new encoder settings
|
||||
8. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -123,15 +122,6 @@ lerobot-edit-dataset \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert a dataset that includes depth maps, customizing the depth encoder
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.depth_encoder.depth_min 0.01 \
|
||||
--operation.depth_encoder.depth_max 10.0 \
|
||||
--operation.depth_encoder.use_log true
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -157,42 +147,11 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `camera_encoder`: Video encoder settings applied to RGB cameras — all sub-fields accessible via `--operation.camera_encoder.<field>`. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `depth_encoder`: Video encoder settings applied to depth-map cameras (e.g. from an Intel RealSense). In addition to the standard encoder fields it exposes the depth quantization knobs (`depth_min`, `depth_max`, `shift`, `use_log`), accessible via `--operation.depth_encoder.<field>`. These quantization settings are persisted to the dataset metadata so depth can be dequantized back to physical units on load. See the [Depth streams](./video_encoding_parameters#depth-streams) section for details.
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). Depth-map cameras are detected automatically and routed to the `depth_encoder`, while RGB cameras use the `camera_encoder`. All episodes, stats, and tasks are preserved.
|
||||
|
||||
#### Re-encode Videos
|
||||
|
||||
Re-encode the videos of an existing video dataset with different encoder settings, without going back to raw frames. RGB videos use the `camera_encoder` and depth videos use the `depth_encoder`. Provide only the encoder(s) you want to re-encode; the other stream type is left untouched.
|
||||
|
||||
```bash
|
||||
# Re-encode all RGB videos with new settings (saves to lerobot/pusht_reencoded by default)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec h264 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.crf 23
|
||||
|
||||
# Re-encode both RGB and depth videos in a dataset with depth maps
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_depth \
|
||||
--operation.type reencode_videos \
|
||||
--operation.camera_encoder.vcodec libx264 \
|
||||
--operation.depth_encoder.vcodec ffv1
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `camera_encoder`: Encoder settings applied to every RGB video. Omit to skip re-encoding RGB videos.
|
||||
- `depth_encoder`: Encoder settings applied to every depth video. Omit to skip re-encoding depth videos.
|
||||
- `num_workers`: Number of parallel workers for processing.
|
||||
|
||||
> [!NOTE]
|
||||
> When re-encoding depth videos, the existing depth quantization parameters (`depth_min`, `depth_max`, `shift`, `use_log`) and the `is_depth_map` flag are **preserved** — re-encoding only changes the codec/quality of the stored stream, not how depth is dequantized on load.
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
|
||||
@@ -65,76 +65,6 @@ All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
---
|
||||
|
||||
## Depth streams
|
||||
|
||||
Depth maps (Intel RealSense, Reachy 2) are stored as their **own video streams** alongside the RGB streams. Raw depth (`uint16` millimetres or `float32` metres) can't survive an 8-bit codec, so LeRobot **quantizes** each map to a 12-bit code (`[0, 4095]`) — logarithmically by default, to match the `1/depth` error profile of depth sensors — then packs it into a high-bit-depth pixel format (`gray12le`) and encodes it with a 12-bit codec.
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
A["Raw depth (uint16 mm / float32 m)"] --> B["Clip to depth_min, depth_max"]
|
||||
B --> C["Quantize to 12-bit code 0–4095 (log or linear)"]
|
||||
C --> D["Pack into gray12le"]
|
||||
D --> E["Encode video (hevc Main 12)"]
|
||||
E --> F[("MP4 + metadata: depth_min/max, shift, use_log")]
|
||||
F -. "load time (depth_output_unit)" .-> G["Dequantize to mm or m"]
|
||||
|
||||
classDef input fill:#e3f2fd,stroke:#1565c0,color:#0d47a1;
|
||||
classDef encode fill:#ede7f6,stroke:#5e35b1,color:#311b92;
|
||||
classDef store fill:#fff8e1,stroke:#f9a825,color:#e65100;
|
||||
classDef load fill:#e8f5e9,stroke:#2e7d32,color:#1b5e20;
|
||||
|
||||
class A input;
|
||||
class B,C,D,E encode;
|
||||
class F store;
|
||||
class G load;
|
||||
```
|
||||
|
||||
Configure the depth pipeline through a parallel **`depth_encoder`** block (`DepthEncoderConfig`). It inherits every `VideoEncoderConfig` field (`vcodec`, `pix_fmt`, `crf`, …) and adds four quantizer knobs, set via `--dataset.depth_encoder.<field>`:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
... \
|
||||
--dataset.depth_encoder.vcodec=hevc \
|
||||
--dataset.depth_encoder.depth_min=0.05 \
|
||||
--dataset.depth_encoder.depth_max=5.0 \
|
||||
--dataset.depth_encoder.use_log=true
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ----------- | ------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"hevc"` | Defaults to HEVC Main 12 (a 12-bit-capable codec). `ffv1` is a lossless alternative. |
|
||||
| `pix_fmt` | `str` | `"gray12le"` | Single-channel 12-bit pixel format used to carry the quantized codes. |
|
||||
| `depth_min` | `float` | `0.01` | Depth in metres mapped to quantum `0`. Values below are clipped on decode. |
|
||||
| `depth_max` | `float` | `10.0` | Depth in metres mapped to quantum `4095`. Values above are clipped on decode. |
|
||||
| `shift` | `float` | `3.5` | Pre-log offset (metres) used in logarithmic quantization for numerical stability near zero. Must satisfy `depth_min + shift > 0`. |
|
||||
| `use_log` | `bool` | `True` | If `true`, quantize in log-space (recommended for typical depth sensors). Set to `false` for uniform/linear quantization. |
|
||||
|
||||
> [!TIP]
|
||||
> `depth_min`, `depth_max`, and `shift` are always interpreted in **metres**, regardless of the input depth's unit. Inputs are auto-detected: integer arrays (e.g. `uint16` millimetres straight from a RealSense) are treated as millimetres, floating arrays as metres.
|
||||
> Pick `depth_min` / `depth_max` to bracket the actual working range of your sensor — quanta outside that range saturate, which can crush detail at the boundaries.
|
||||
|
||||
Depth features are flagged with `"is_depth_map": true` in `meta/info.json`, and their quantizer settings (`video.depth_min`, `video.depth_max`, `video.shift`, `video.use_log`) are persisted — which is what lets depth be **dequantized back to physical units** on load.
|
||||
|
||||
### Output unit at load time
|
||||
|
||||
`depth_encoder` is a **record-time** concern. The unit that depth maps are dequantized to on _load_ (e.g. during training) is set separately by the read-time flag `--dataset.depth_output_unit`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.depth_output_unit=m \
|
||||
--policy.type=act
|
||||
```
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------------- | ----- | ------- | -------------------------------------------------------------------------------------------- |
|
||||
| `depth_output_unit` | `str` | `"mm"` | Physical unit depth maps are dequantized to on load: `"mm"` (millimetres) or `"m"` (metres). |
|
||||
|
||||
> [!TIP]
|
||||
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
@@ -152,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"is_depth_map": false,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -167,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||
|
||||
Spawns one single-GPU ``h200`` job that:
|
||||
|
||||
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (each episode generates its own subtasks +
|
||||
memory),
|
||||
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||
run. For larger datasets, scale to ``h200x4`` and raise
|
||||
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||
"--vlm.num_gpus=1 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
+10
-36
@@ -115,8 +115,8 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"wandb>=0.24.0,<0.28.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
]
|
||||
hardware = [
|
||||
"lerobot[pynput-dep]",
|
||||
@@ -142,8 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
@@ -178,12 +177,7 @@ unitree_g1 = [
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||
reachy2 = [
|
||||
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||
"grpcio<=1.73.1",
|
||||
"protobuf<=6.32.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
@@ -205,7 +199,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
@@ -218,43 +212,26 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||
# (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||
# uv's single unified lock would then cap ``torch`` for every extra
|
||||
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||
# install it locally only if you run your own ``vllm serve``.
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -304,7 +281,6 @@ all = [
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -315,7 +291,6 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
@@ -338,7 +313,6 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
@@ -357,7 +331,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# --- Frame input: timestamped contact sheets (always on) ---------------
|
||||
# The subtask describe/segment passes ALWAYS render the episode as
|
||||
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
|
||||
# grids with each frame's timestamp burned into its corner, so the VLM
|
||||
# cites the exact source time of a boundary directly. This is far cheaper
|
||||
# in vision tokens than one image per frame (≈2× faster subtask generation
|
||||
# in practice), which is why the sampling is dense by default.
|
||||
#
|
||||
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
|
||||
frames_per_second: float = 2.0
|
||||
# Frame budget per VLM call (= columns × rows × sheets). When a whole
|
||||
# episode sampled at ``frames_per_second`` exceeds this, the episode is
|
||||
# AUTOMATICALLY split into consecutive windows of
|
||||
# ``max_frames_per_prompt`` frames each (one describe→segment call per
|
||||
# window, still at the full ``frames_per_second`` density), and the
|
||||
# per-window spans are merged + stitched into one contiguous cover. So an
|
||||
# episode of any length is always covered at the full sampling density.
|
||||
max_frames_per_prompt: int = 60
|
||||
contact_sheet_columns: int = 5
|
||||
contact_sheet_frames_per_sheet: int = 20
|
||||
contact_sheet_frame_width: int = 224
|
||||
contact_sheet_quality: int = 84
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||
# invented from the task text (+1 VLM call/episode).
|
||||
subtask_describe_first: bool = True
|
||||
|
||||
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||
emit_plan: bool = True
|
||||
|
||||
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
|
||||
# Symmetric counterpart of ``emit_plan``.
|
||||
emit_memory: bool = True
|
||||
|
||||
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||
|
||||
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAugAxesConfig:
|
||||
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
synonym_paraphrase: int = 3
|
||||
omit_arm: int = 3
|
||||
omit_orientation: int = 2
|
||||
omit_grasp_method: int = 2
|
||||
combined_omissions: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||
restrict_to_default_camera: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
|
||||
# auto_serve=True); ``stub`` is for tests.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-27B"
|
||||
|
||||
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
|
||||
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||
max_model_len: int | None = None
|
||||
|
||||
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||
|
||||
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
|
||||
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||
# is on and ``new_repo_id`` unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||
new_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/``.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
|
||||
# library default (torchcodec when available, else PyAV). Or pin
|
||||
# ``"torchcodec"`` / ``"pyav"`` explicitly.
|
||||
video_backend: str | None = None
|
||||
|
||||
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
@@ -1,481 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import VideoEncoderConfig
|
||||
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
|
||||
|
||||
from .reader import EpisodeRecord, snap_to_frame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend forwarded to
|
||||
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
|
||||
# uses the library default (torchcodec when available, else PyAV).
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
# Serializes decode_video_frames calls: torchcodec hands out one
|
||||
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
|
||||
# is not safe to drive from multiple threads at once.
|
||||
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||
# only for video-stored cameras. Image-stored cameras (also in
|
||||
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||
# default — to video keys.
|
||||
keys = list(self._meta.video_keys)
|
||||
# Last-resort fallback: if metadata didn't surface any video keys but
|
||||
# the caller explicitly named a camera (``--vlm.camera_key=...``),
|
||||
# trust them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
# Snap each request to the nearest real frame timestamp: callers
|
||||
# sample uniform grids whose points land mid-frame, and
|
||||
# ``decode_video_frames`` rejects queries farther than
|
||||
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
|
||||
# repeat queries through the cache.
|
||||
if record.frame_timestamps:
|
||||
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks or extraction
|
||||
failed. Skips re-extract when the cached clip already exists.
|
||||
Re-encodes to H.264 via
|
||||
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
|
||||
mp4 is decodable by every downstream video processor — stream-copy
|
||||
would inherit the source codec (often AV1 in modern LeRobot
|
||||
datasets), which vllm's libav build cannot decode.
|
||||
"""
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
|
||||
try:
|
||||
reencode_video(
|
||||
src,
|
||||
out_path,
|
||||
video_encoder=encoder,
|
||||
overwrite=True,
|
||||
start_time_s=from_timestamp,
|
||||
end_time_s=to_timestamp,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
|
||||
)
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
|
||||
one explicitly). Returns one frame per requested timestamp, or ``[]``
|
||||
if decoding failed — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
# The module phases decode under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
|
||||
# per-file decoder is single-threaded, so serialize decodes on a
|
||||
# dedicated lock. Frame extraction is a small fraction of episode
|
||||
# wall time (VLM calls dominate), so the contention is cheap.
|
||||
with self._decode_lock:
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time so a silent vqa-module no-op (every
|
||||
# prompt skipped because frames_at returned []) is debuggable from
|
||||
# the job log instead of post-hoc parquet inspection. Subsequent
|
||||
# failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = self._warned_decode_fail
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
self.video_backend,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
|
||||
|
||||
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
|
||||
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
|
||||
|
||||
A solid black badge with white text, so a VLM reading a contact sheet can
|
||||
cite the exact source time of each tile (e.g. ``012.50s``) directly,
|
||||
instead of the caller having to map tile position back to time. Mirrors
|
||||
the macrodata/refiner contact-sheet convention.
|
||||
"""
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
result = image.copy()
|
||||
draw = ImageDraw.Draw(result)
|
||||
font = ImageFont.load_default()
|
||||
label = f"{timestamp:06.2f}s"
|
||||
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
|
||||
text_w, text_h = right - left, bottom - top
|
||||
pad = max(3, round(min(image.width, image.height) * 0.018))
|
||||
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
|
||||
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
|
||||
return result
|
||||
|
||||
|
||||
def to_contact_sheet_blocks(
|
||||
frames: Sequence[Any],
|
||||
timestamps: Sequence[float],
|
||||
*,
|
||||
columns: int = 5,
|
||||
frames_per_sheet: int = 20,
|
||||
frame_width: int = 224,
|
||||
quality: int = 84,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
|
||||
|
||||
Each frame is resized to ``frame_width`` wide, stamped with its
|
||||
episode-relative timestamp, and tiled row-major into grids of
|
||||
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
|
||||
block is returned per grid; many frames collapse into a few images, so a
|
||||
long episode's temporal coverage stays dense at a fraction of the vision
|
||||
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
|
||||
aligned and equal length. Returns ``[]`` for empty input.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
columns = max(1, columns)
|
||||
frames_per_sheet = max(1, frames_per_sheet)
|
||||
rows_per_sheet = math.ceil(frames_per_sheet / columns)
|
||||
|
||||
tiles: list[PIL.Image.Image] = []
|
||||
for ts, frame in zip(timestamps, frames, strict=False):
|
||||
img = _frame_to_pil(frame)
|
||||
if not isinstance(img, PIL.Image.Image):
|
||||
continue
|
||||
img = img.convert("RGB")
|
||||
if img.width != frame_width:
|
||||
height = max(1, round(img.height * frame_width / img.width))
|
||||
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
|
||||
tiles.append(_draw_timestamp_badge(img, float(ts)))
|
||||
if not tiles:
|
||||
return []
|
||||
|
||||
blocks: list[dict[str, Any]] = []
|
||||
for start in range(0, len(tiles), frames_per_sheet):
|
||||
chunk = tiles[start : start + frames_per_sheet]
|
||||
cell_w = max(tile.width for tile in chunk)
|
||||
cell_h = max(tile.height for tile in chunk)
|
||||
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
|
||||
for i, tile in enumerate(chunk):
|
||||
x = (i % columns) * cell_w
|
||||
y = (i // columns) * cell_h
|
||||
sheet.paste(tile, (x, y))
|
||||
# JPEG round-trip at ``quality`` to match the refiner convention and
|
||||
# shrink the wire payload; vision-token count is set by resolution, so
|
||||
# the real saving is the grid packing, not the codec.
|
||||
buf = io.BytesIO()
|
||||
sheet.save(buf, format="JPEG", quality=quality)
|
||||
buf.seek(0)
|
||||
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
|
||||
return blocks
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not self._warned_no_camera:
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
|
||||
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||
only the provider's default camera (the single ``--vlm.camera_key``
|
||||
stream), matching the plan / interjection modules so the whole
|
||||
pipeline focuses on one view.
|
||||
"""
|
||||
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
if getattr(self.config, "restrict_to_default_camera", False):
|
||||
default = getattr(self.frame_provider, "camera_key", None)
|
||||
if default and default in all_cameras:
|
||||
return [default]
|
||||
# ``restrict_to_default_camera`` is set but the configured default
|
||||
# isn't one the provider exposes. Returning it anyway would make
|
||||
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||
# fall through to every available camera instead.
|
||||
if default:
|
||||
logging.getLogger(__name__).warning(
|
||||
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||
default,
|
||||
all_cameras,
|
||||
)
|
||||
return all_cameras
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("interjections_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("interjections_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
first_ts = float(frame_timestamps[0])
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(first_ts, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -1,780 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
null_provider,
|
||||
to_contact_sheet_blocks,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Prepended to every describe / segment prompt so the VLM knows the images are
|
||||
# timestamped contact-sheet grids, not a single video, and reads the burned-in
|
||||
# per-tile timestamp when choosing boundaries.
|
||||
def _contact_sheet_preamble(columns: int) -> str:
|
||||
return (
|
||||
"CONTACT SHEETS — how to read the images below:\n"
|
||||
f"- Each image is a grid of sampled video frames, {columns} per row, "
|
||||
"with time running left-to-right then top-to-bottom (row-major).\n"
|
||||
"- Each frame has its timestamp burned into the top-left corner, e.g. "
|
||||
'"012.50s". Use that printed timestamp (not the tile position) when you '
|
||||
"choose start/end times; boundaries should land on or near a printed "
|
||||
"timestamp.\n"
|
||||
"- Frames continue across grids: an action may span the end of one sheet "
|
||||
"and the start of the next, so do not place a boundary just because a new "
|
||||
"image begins.\n\n"
|
||||
)
|
||||
|
||||
|
||||
# Appended to every describe (and segment) prompt. A visual, causal definition
|
||||
# of where one event ends and the next begins — adapted from macrodata/refiner —
|
||||
# to sharpen cut points while the existing prompt keeps owning the imperative
|
||||
# phrasing.
|
||||
_CAUSAL_BOUNDARY_RULES = (
|
||||
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
|
||||
"- Start a new event whenever the world state changes: an object becomes "
|
||||
"held (the gripper closes on it), an object is released (the gripper opens "
|
||||
"and it stays put), an object reaches a new location, a lid/door/drawer "
|
||||
"changes open/closed state, a tool starts or stops affecting a surface, or "
|
||||
"contents visibly move (e.g. poured).\n"
|
||||
"- If a single action changes the same state gradually and continuously, "
|
||||
"keep it as ONE event — do not split it.\n"
|
||||
"- If the same action repeats on different objects or target locations, "
|
||||
"treat each repetition as a separate event.\n"
|
||||
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
|
||||
"tiny hand adjustments."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||
# numbered list of still-todo subtasks, so re-emitting at each
|
||||
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||
# exactly what's left to do.
|
||||
if self.config.emit_plan:
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start;
|
||||
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
|
||||
prior_memory = ""
|
||||
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
|
||||
for i, span in memory_boundaries:
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
|
||||
|
||||
The prompt is always prefixed with a short explanation of how to read
|
||||
the timestamped grids, so the model treats them as one ordered
|
||||
sequence of frames rather than unrelated images.
|
||||
"""
|
||||
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||
"""One VLM call → variants along the 5-axis taxonomy.
|
||||
|
||||
Variants from all axes are flattened into a single list (the
|
||||
downstream pipeline doesn't need to know about the per-axis
|
||||
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||
is preserved for reproducibility: synonym_paraphrase first,
|
||||
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||
then combined_omissions.
|
||||
"""
|
||||
if not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_aug_axes").format(
|
||||
base_task=base_task,
|
||||
n_synonym=axes_cfg.synonym_paraphrase,
|
||||
n_omit_arm=axes_cfg.omit_arm,
|
||||
n_omit_orientation=axes_cfg.omit_orientation,
|
||||
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||
n_combined=axes_cfg.combined_omissions,
|
||||
)
|
||||
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
ordered_axes = (
|
||||
"synonym_paraphrase",
|
||||
"omit_arm",
|
||||
"omit_orientation",
|
||||
"omit_grasp_method",
|
||||
"combined_omissions",
|
||||
)
|
||||
flat: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for axis in ordered_axes:
|
||||
entries = result.get(axis)
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
key = item.strip().strip('"').strip("'")
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Timestamped contact sheets for the describe / segmentation prompts.
|
||||
|
||||
Always renders the (optionally windowed) episode as contact sheets:
|
||||
frames sampled at ``frames_per_second`` and packed into timestamped
|
||||
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
|
||||
episodes that exceed it are windowed upstream in
|
||||
:meth:`_generate_subtasks` so each call stays within budget while the
|
||||
full episode keeps its sampling density.
|
||||
|
||||
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
|
||||
(``ts - w0``) to match the window-relative time frame the
|
||||
segmentation prompt works in (spans are offset back to absolute time
|
||||
afterwards).
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
rel = [ts - w0 for ts in timestamps[: len(frames)]]
|
||||
return self._contact_sheet_blocks(frames, rel)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_frames_per_prompt)
|
||||
timestamps = self._uniform_episode_timestamps(record, n)
|
||||
frames = self.frame_provider.frames_at(record, timestamps)
|
||||
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
|
||||
|
||||
@staticmethod
|
||||
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
|
||||
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
|
||||
ts = record.frame_timestamps
|
||||
if n >= len(ts):
|
||||
return [float(t) for t in ts]
|
||||
t0, t_last = float(ts[0]), float(ts[-1])
|
||||
if t_last <= t0 or n <= 1:
|
||||
return [t0] * max(1, n)
|
||||
step = (t_last - t0) / (n - 1)
|
||||
return [t0 + i * step for i in range(n)]
|
||||
|
||||
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
|
||||
"""Build timestamped contact-sheet image blocks from decoded frames."""
|
||||
return to_contact_sheet_blocks(
|
||||
frames,
|
||||
timestamps,
|
||||
columns=self.config.contact_sheet_columns,
|
||||
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
|
||||
frame_width=self.config.contact_sheet_frame_width,
|
||||
quality=self.config.contact_sheet_quality,
|
||||
)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections (event-driven). The
|
||||
interjection text is forwarded into the prompt so the refreshed plan
|
||||
reflects the user's correction.
|
||||
"""
|
||||
if not self.config.emit_plan:
|
||||
return
|
||||
existing = staging.read("plan")
|
||||
# Pass the last frame timestamp so the final span is closed (else its
|
||||
# end == start, zero duration, and a refresh inside it is missed).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||
|
||||
Single call (default): watch video → emit subtask JSON.
|
||||
|
||||
Multi-call (opt-in, higher quality, more VLM calls):
|
||||
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||
its description is injected into the segmentation prompt so
|
||||
the model segments its own grounded observations instead of
|
||||
pattern-matching the task text.
|
||||
2. segmentation — emit subtask JSON (as before).
|
||||
"""
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Auto-windowing (keeps the full sampling density) --------
|
||||
# Contact sheets are cheap, but a whole long episode sampled at
|
||||
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
|
||||
# When it does, split into consecutive windows of exactly that many
|
||||
# frames (one describe→segment call each, still at the full sampling
|
||||
# density), then merge + stitch — so an episode of any length is
|
||||
# covered at full density rather than subsampled into one sparse call.
|
||||
fps = max(1e-6, float(self.config.frames_per_second))
|
||||
n_whole = int(round(episode_duration * fps)) + 1
|
||||
if n_whole > self.config.max_frames_per_prompt:
|
||||
window_s = self.config.max_frames_per_prompt / fps
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, effective_task)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the video) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
# ---- Pass 2: segmentation ------------------------------------
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=effective_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
cleaned = self._clean_spans(spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||
# with no active subtask. Always stitch into a contiguous
|
||||
# [t0, t_last] cover.
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment chain on each window's own
|
||||
frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = self._with_causal_rules(
|
||||
load_prompt("plan_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _with_causal_rules(prompt: str) -> str:
|
||||
"""Append the causal event-boundary rules to a describe/segment prompt."""
|
||||
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
|
||||
|
||||
def _clean_spans(
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||
cost a round-trip per episode/refresh and could diverge).
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
|
||||
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||
starting at or after ``refresh_t`` are included — so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("plan_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -1,12 +0,0 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -1,46 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -1,36 +0,0 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -1,27 +0,0 @@
|
||||
You are watching a teleoperated robot demonstration from a single
|
||||
camera. The user asked the robot to: "{episode_task}"
|
||||
|
||||
This is an OBSERVATION pass. Watch the entire clip and describe, in
|
||||
chronological order, ONLY what the robot physically does — the concrete
|
||||
motions, approaches, contacts, grasps, releases, and relocations you can
|
||||
actually SEE in the frames.
|
||||
|
||||
Hard rules:
|
||||
- Describe only motion visible in the video. Do NOT use the task
|
||||
instruction to guess steps that aren't shown. The instruction is the
|
||||
goal; the video is ground truth.
|
||||
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
|
||||
the single field below. Just narrate what happens.
|
||||
- Give an approximate timestamp (in seconds) for each distinct event,
|
||||
e.g. "0.0-1.4s: the base drives forward toward the stove".
|
||||
- Do NOT invent objects, grasps, destinations, or steps. If the robot
|
||||
only does one thing (e.g. it just navigates and the clip ends), say
|
||||
exactly that and nothing more.
|
||||
- Be concrete and literal. "the gripper closes on the mug" — not "the
|
||||
robot prepares to make coffee".
|
||||
|
||||
Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"description": "<chronological, timestamped description of ONLY what is visible>"
|
||||
}}
|
||||
@@ -1,112 +0,0 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{observation_block}GROUNDING — read this first, it overrides everything below:
|
||||
- Label ONLY what the robot actually does in the video. Every subtask
|
||||
you emit must correspond to motion you can SEE in specific frames.
|
||||
- Do NOT invent, anticipate, or pad. If the robot only does one thing
|
||||
(e.g. it just navigates to a location and the clip ends), emit
|
||||
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
|
||||
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
|
||||
subtasks than the ceiling is not just allowed, it is expected for
|
||||
short / atomic demonstrations. One correct subtask is far better
|
||||
than several invented ones.
|
||||
- If the video does not clearly show the action implied by the task,
|
||||
describe what you actually see — do NOT fabricate the task's steps
|
||||
from the instruction text. The instruction tells you the goal; the
|
||||
VIDEO is the ground truth for what happened.
|
||||
|
||||
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
|
||||
fires ONLY on the spatial situation it names; it must not change how you
|
||||
label any other situation):
|
||||
- STACK vs PUT: if an object is placed ON TOP OF another specific object
|
||||
(not on a flat table / shelf / counter), use "stack ... on ...", not
|
||||
"put". "stack blue book on green book", NOT "put blue book on table".
|
||||
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
|
||||
receptacle (push-fit), use "insert ... into ...", not "put".
|
||||
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
|
||||
on the object and the object moves WITH the hand, it is "pick up" /
|
||||
"retrieve" (object leaves its location). If the gripper OPENS and the
|
||||
object stays where the hand left it, it is "put" / "place" (object
|
||||
arrives at a location). Decide by which way the object moves, not by
|
||||
where the hand ends up.
|
||||
- POUR vs PUT: only use "pour" when the source is tilted and contents
|
||||
flow out; moving a full container without tilting is "put"/"place".
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,67 +0,0 @@
|
||||
You are generating structured augmentations of a robot task instruction
|
||||
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||
a specific element of the task while preserving its meaning.
|
||||
|
||||
Original task: "{base_task}"
|
||||
|
||||
Produce variants along five named axes. Each axis has a target count.
|
||||
The whole batch should expose the policy to maximum linguistic diversity
|
||||
WITHOUT changing what the robot is supposed to do.
|
||||
|
||||
Axes and target counts:
|
||||
|
||||
synonym_paraphrase ({n_synonym}):
|
||||
Different wording / verbs / sentence structure. ALL information
|
||||
from the original task is preserved — same object, same arm
|
||||
specification if present, same orientation if present, same grasp
|
||||
if present.
|
||||
|
||||
omit_arm ({n_omit_arm}):
|
||||
Drop the left/right/both arm specification from the task. Skip
|
||||
entirely (emit 0 entries) if the original task does NOT mention an
|
||||
arm. Do not invent an arm specification just to omit it.
|
||||
|
||||
omit_orientation ({n_omit_orientation}):
|
||||
Drop orientation cues (upright, sideways, facing the user,
|
||||
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||
present in the original task.
|
||||
|
||||
omit_grasp_method ({n_omit_grasp_method}):
|
||||
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||
|
||||
combined_omissions ({n_combined}):
|
||||
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||
orientation, grasp_method) appear in the original task.
|
||||
|
||||
Hard rules:
|
||||
- Each variant MUST preserve the core action, the target object, AND
|
||||
the goal / destination. Do not change which object is involved, where
|
||||
it goes, or the high-level action. "Navigate to the stove" may become
|
||||
"go to the stove" or "head over to the stove" — it must NEVER become
|
||||
"wander around the kitchen", "explore the room", or anything that
|
||||
drops or generalises the stove destination. If you cannot vary the
|
||||
wording without changing the goal, emit fewer variants.
|
||||
- Only the FIVE listed elements (wording, arm, orientation, grasp
|
||||
method, or a combination) may be varied or omitted. The verb's
|
||||
meaning, the object, and the destination are fixed.
|
||||
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||
- Each variant must be DISTINCT from every other variant in the entire
|
||||
output, both within and across axes. Near-duplicates are not allowed.
|
||||
- If an axis cannot reach its target count because the original task
|
||||
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||
axis with paraphrases that belong to a different axis.
|
||||
- Variants should not all start with verbs — vary sentence structure
|
||||
(some imperative, some polite request, some question).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||
"omit_arm": ["<v1>", "<v2>", ...],
|
||||
"omit_orientation": ["<v1>", ...],
|
||||
"omit_grasp_method": ["<v1>", ...],
|
||||
"combined_omissions": ["<v1>", ...]
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,17 +0,0 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
@@ -1,216 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
@@ -1,332 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||
try:
|
||||
column = column_for_style(row.get("style"))
|
||||
except ValueError:
|
||||
continue
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||
return
|
||||
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
@@ -1,617 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output from the server,
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client.
|
||||
|
||||
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||
optionally auto-spawning the server via ``auto_serve`` /
|
||||
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||
backends were removed to keep the support surface to the HF Jobs path.
|
||||
|
||||
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||
callable; it is rejected here to make accidental misuse obvious.
|
||||
"""
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend in {"vllm", "transformers"}:
|
||||
raise ValueError(
|
||||
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||
)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _bind_serve_port(cmd: str, port: int) -> str:
|
||||
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
|
||||
if present, else append ``--port`` when the command omits it (leaving an
|
||||
explicit ``--port`` untouched). Shared by the single- and parallel-server
|
||||
paths so a serve_command never reaches the server with a literal
|
||||
``{port}``."""
|
||||
if "{port}" in cmd:
|
||||
return cmd.replace("{port}", str(port))
|
||||
if "--port" not in cmd:
|
||||
return f"{cmd} --port {port}"
|
||||
return cmd
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = _bind_serve_port(base_cmd, port)
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
# Bind the single server to ``serve_port`` (what ``api_base`` below
|
||||
# targets): substitute a literal ``{port}`` placeholder, else append
|
||||
# ``--port``. Without this a serve_command carrying ``{port}`` would
|
||||
# reach the server unsubstituted and fail to parse.
|
||||
cmd = _bind_serve_port(cmd, config.serve_port)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
@@ -1,341 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -105,9 +105,8 @@ def raw_observation_to_observation(
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
if image.dtype == torch.uint8:
|
||||
image = image.type(torch.float32) / 255
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
@@ -436,7 +436,7 @@ class OpenCVCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame (blocking call)
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
@@ -445,9 +445,8 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
@@ -485,8 +484,6 @@ class OpenCVCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -268,13 +268,13 @@ class RealSenseCamera(Camera):
|
||||
)
|
||||
|
||||
if len(found_devices) > 1:
|
||||
serial_numbers = [dev["id"] for dev in found_devices]
|
||||
serial_numbers = [dev["serial_number"] for dev in found_devices]
|
||||
raise ValueError(
|
||||
f"Multiple RealSense cameras found with name '{name}'. "
|
||||
f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
|
||||
)
|
||||
|
||||
serial_number = str(found_devices[0]["id"])
|
||||
serial_number = str(found_devices[0]["serial_number"])
|
||||
return serial_number
|
||||
|
||||
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
|
||||
@@ -332,8 +332,8 @@ class RealSenseCamera(Camera):
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width, 1)
|
||||
of type `np.uint16` (raw depth values in millimeters).
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
@@ -465,8 +465,8 @@ class RealSenseCamera(Camera):
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color/depth frame (blocking call with 10s timeout)
|
||||
2. Stores result in latest_color_frame/latest_depth_frame and updates timestamp (thread-safe)
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -474,9 +474,8 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
@@ -487,8 +486,6 @@ class RealSenseCamera(Camera):
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
if processed_depth_frame.ndim == 2: # (H, W) -> (H, W, 1)
|
||||
processed_depth_frame = processed_depth_frame[..., np.newaxis]
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -525,8 +522,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive(): # pragma: no cover
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
@@ -537,6 +532,7 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -579,6 +575,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -614,73 +611,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[np.uint16]:
|
||||
"""Read the latest depth frame asynchronously, in millimeters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W, 1)``, where each
|
||||
pixel is the distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in millimeters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.uint16`` of shape ``(H, W, 1)``, where each pixel is the
|
||||
distance from the sensor in millimeters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(f"{self}: cannot read depth — camera was configured with use_depth=False.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
|
||||
@@ -249,9 +249,8 @@ class ZMQCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
stop_event = self.stop_event
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
@@ -293,8 +292,6 @@ class ZMQCamera(Camera):
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
if self.thread.is_alive():
|
||||
logger.warning(f"{self} read thread did not terminate within timeout.")
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -244,72 +243,3 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(
|
||||
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||
) -> None:
|
||||
state: dict = {"step": step}
|
||||
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||
# compute_sampler_state).
|
||||
if num_processes is not None:
|
||||
state["num_processes"] = num_processes
|
||||
if batch_size is not None:
|
||||
state["batch_size"] = batch_size
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -96,8 +75,6 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -123,10 +100,6 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -139,9 +112,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -149,8 +120,6 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -162,12 +131,10 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -35,11 +35,8 @@ from .types import (
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -60,12 +57,8 @@ __all__ = [
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
"DepthEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
# Factories
|
||||
"encoder_config_from_video_info",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import DepthEncoderConfig, VideoEncoderConfig, camera_encoder_defaults, depth_encoder_defaults
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||
private: bool | None = None
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
@@ -60,8 +60,6 @@ class DatasetRecordConfig:
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Video encoder settings for depth-map MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys.
|
||||
depth_encoder: DepthEncoderConfig = field(default_factory=depth_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -35,17 +35,12 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
# Physical unit depth maps are dequantized to at load time: "mm" (millimetres) or "m" (metres).
|
||||
# Has no effect on datasets without depth cameras.
|
||||
depth_output_unit: str = "mm"
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.depth_output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
|
||||
@@ -177,12 +177,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
|
||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar, Self
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -36,12 +36,11 @@ HW_VIDEO_CODECS = [
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
|
||||
{"h264", "hevc", "libsvtav1", "ffv1", "auto", *HW_VIDEO_CODECS}
|
||||
)
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
@@ -53,19 +52,6 @@ VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
# Default depth quantization and encoding parameters.
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
DEFAULT_DEPTH_PIX_FMT: str = "gray12le"
|
||||
|
||||
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
|
||||
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
@@ -100,10 +86,6 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Source-data channel count this encoder is expected to handle (3 for RGB,
|
||||
# 1 for depth, etc.)
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 3
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
@@ -112,9 +94,9 @@ class VideoEncoderConfig:
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Parse the ``video.*`` keys of a feature ``info`` block into
|
||||
constructor kwargs.
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
@@ -133,15 +115,7 @@ class VideoEncoderConfig:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
return cls(**kwargs)
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
@@ -164,9 +138,7 @@ class VideoEncoderConfig:
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(
|
||||
self.vcodec, self.pix_fmt, self.get_codec_options(), channels=self._DEFAULT_CHANNELS
|
||||
)
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
@@ -246,10 +218,6 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -265,75 +233,3 @@ class VideoEncoderConfig:
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantizer.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"gray12le"``.
|
||||
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "gray12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
|
||||
``shift`` / ``use_log``) on top of the base parser. Missing keys
|
||||
fall back to the class defaults.
|
||||
"""
|
||||
kwargs = super()._kwargs_from_video_info(video_info)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
|
||||
def encoder_config_from_video_info(video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Build the appropriate encoder config from a feature's ``info`` block.
|
||||
|
||||
Dispatches to :class:`DepthEncoderConfig` when the dict marks the feature
|
||||
as a depth map and to :class:`VideoEncoderConfig`
|
||||
otherwise.
|
||||
|
||||
Args:
|
||||
video_info: A feature's ``info`` dict as persisted in ``info.json``,
|
||||
or ``None`` (treated as an empty dict).
|
||||
|
||||
Returns:
|
||||
A :class:`DepthEncoderConfig` for depth features, otherwise a
|
||||
:class:`VideoEncoderConfig`.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
is_depth = bool(video_info.get("is_depth_map") or video_info.get("video.is_depth_map"))
|
||||
cls: type[VideoEncoderConfig] = DepthEncoderConfig if is_depth else VideoEncoderConfig
|
||||
return cls.from_video_info(video_info)
|
||||
|
||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -82,7 +82,6 @@ __all__ = [
|
||||
"aggregate_stats",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"compute_sampler_state",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
|
||||
@@ -286,8 +286,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
@@ -305,8 +303,6 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
@@ -355,12 +351,8 @@ def aggregate_datasets(
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
|
||||
)
|
||||
data_idx = aggregate_data(
|
||||
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
|
||||
)
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
@@ -375,9 +367,7 @@ def aggregate_datasets(
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
def aggregate_videos(
|
||||
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
|
||||
):
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||
|
||||
Handles video file concatenation and rotation based on file size limits.
|
||||
@@ -389,7 +379,6 @@ def aggregate_videos(
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -450,7 +439,7 @@ def aggregate_videos(
|
||||
src_size = get_file_size_in_mb(src_path)
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
@@ -488,7 +477,7 @@ def aggregate_videos(
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
contains_images=contains_images,
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
|
||||
contains_images: Whether the data contains images requiring special handling.
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if not concatenate or dst_size + src_size >= max_mb:
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
@@ -59,8 +59,6 @@ class RunningQuantileStats:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
# Promote integer and low-precision inputs before computing squared statistics.
|
||||
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
@@ -506,10 +504,8 @@ def compute_episode_stats(
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
For 'image'/'video' features, stats are computed per channel and kept with a
|
||||
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
|
||||
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
|
||||
this rescaling and remain in their stored units.
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
@@ -533,12 +529,8 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = (
|
||||
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
)
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
@@ -558,10 +550,8 @@ def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape not in ((3, 1, 1), (1, 1, 1)):
|
||||
raise ValueError(
|
||||
f"Shape of quantile '{key}' must be (3,1,1) or (1,1,1) but is {value.shape} instead."
|
||||
)
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -337,25 +337,6 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -599,51 +580,29 @@ class LeRobotDatasetMetadata:
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
preserve_keys: Iterable[str] | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate or refresh per-feature video info in ``info.json``.
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Two modes, selected by ``preserve_keys``:
|
||||
|
||||
- **Populate** (``None``, default): write info for video keys that lack it,
|
||||
skip the rest. Used when first encoding a dataset.
|
||||
- **Refresh** (any iterable): re-probe and overwrite existing info, keeping
|
||||
the listed keys. Used after re-encoding to preserve data-intrinsic entries
|
||||
(``is_depth_map``, depth quantization params) while codec params change.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
video_encoder: Encoder configuration used to produce the
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
preserve_keys: ``None`` (default) for populate-once mode. An iterable
|
||||
(possibly empty) switches to refresh mode, keeping these keys'
|
||||
existing values while recomputing the rest.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
refresh = preserve_keys is not None
|
||||
preserve_set = set(preserve_keys or ())
|
||||
for key in video_keys:
|
||||
existing = self.features[key].get("info") or {}
|
||||
# ``is_depth_map`` is written at feature creation and does not count as real video info here.
|
||||
already_populated = bool(set(existing.keys()) - {"is_depth_map"})
|
||||
# Populate-once: never clobber info that has already been written unless a refresh is requested.
|
||||
if already_populated and not refresh:
|
||||
continue
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
new_info = get_video_info(video_path, video_encoder=video_encoder)
|
||||
# Drop preserved keys so the existing values win on merge.
|
||||
new_info = {k: v for k, v in new_info.items() if k not in preserve_set}
|
||||
self.info.features[key]["info"] = {**existing, **new_info}
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -22,10 +22,7 @@ from pathlib import Path
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.configs.video import DepthEncoderConfig
|
||||
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .depth_utils import dequantize_depth
|
||||
from .feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
@@ -54,7 +51,6 @@ class DatasetReader:
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
@@ -72,10 +68,6 @@ class DatasetReader:
|
||||
relative timestamp offsets for temporal context windows.
|
||||
image_transforms: Optional torchvision v2 transform applied to
|
||||
visual features.
|
||||
return_uint8: If True, return RGB video frames as raw uint8 tensors
|
||||
instead of normalized float32.
|
||||
depth_output_unit: Physical unit depth maps are dequantized to
|
||||
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self.root = root
|
||||
@@ -84,7 +76,6 @@ class DatasetReader:
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
@@ -95,12 +86,6 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
##TODO(CarolinePascal): Should we rather use a more lightweight structure ?
|
||||
self._depth_encoder_configs: dict[str, DepthEncoderConfig] = {
|
||||
vid_key: DepthEncoderConfig.from_video_info(self._meta.features[vid_key].get("info"))
|
||||
for vid_key in self._meta.depth_keys
|
||||
}
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
@@ -262,18 +247,7 @@ class DatasetReader:
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
is_depth=vid_key in self._meta.depth_keys,
|
||||
)
|
||||
if vid_key in self._meta.depth_keys:
|
||||
depth_encoder = self._depth_encoder_configs[vid_key]
|
||||
frames = dequantize_depth(
|
||||
frames,
|
||||
depth_min=depth_encoder.depth_min,
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_unit=self._depth_output_unit,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -36,14 +36,7 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -54,7 +47,6 @@ from .compute_stats import (
|
||||
compute_relative_action_stats,
|
||||
)
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .image_writer import write_image
|
||||
from .io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
@@ -69,13 +61,12 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEPTH_FILE_PATTERN,
|
||||
IMAGE_FILE_PATTERN,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
@@ -270,8 +261,6 @@ def merge_datasets(
|
||||
datasets: list[LeRobotDataset],
|
||||
output_repo_id: str,
|
||||
output_dir: str | Path | None = None,
|
||||
concatenate_videos: bool = True,
|
||||
concatenate_data: bool = True,
|
||||
) -> LeRobotDataset:
|
||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||
|
||||
@@ -281,8 +270,6 @@ def merge_datasets(
|
||||
datasets: List of LeRobotDatasets to merge.
|
||||
output_repo_id: Merged dataset identifier.
|
||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||
"""
|
||||
if not datasets:
|
||||
raise ValueError("No datasets to merge")
|
||||
@@ -297,8 +284,6 @@ def merge_datasets(
|
||||
aggr_repo_id=output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=output_dir,
|
||||
concatenate_videos=concatenate_videos,
|
||||
concatenate_data=concatenate_data,
|
||||
)
|
||||
|
||||
merged_dataset = LeRobotDataset(
|
||||
@@ -609,7 +594,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -623,7 +608,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
video_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -648,13 +633,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
codec_options = video_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(video_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = video_encoder.pix_fmt
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -741,7 +726,7 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
video_encoder = encoder_config_from_video_info(
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
@@ -825,7 +810,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
video_encoder,
|
||||
camera_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1159,15 +1144,15 @@ def _save_episode_images_for_video(
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
@@ -1199,14 +1184,13 @@ def _save_batch_episodes_images(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
|
||||
# Define function to save a single image with global frame index
|
||||
# Defined once outside the loop to avoid repeated closure creation
|
||||
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
|
||||
i, item = i_item_tuple
|
||||
write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
|
||||
img = item[img_key_param]
|
||||
# Use global frame index for naming
|
||||
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
episode_durations = []
|
||||
@@ -1297,7 +1281,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1311,7 +1295,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
video_encoder: Video encoder settings used for calibration encoding.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1336,11 +1320,10 @@ def _estimate_frame_size_via_calibration(
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
sample_indices = range(from_idx, from_idx + num_frames)
|
||||
|
||||
# Save calibration frames using the suffix/format the encoder expects.
|
||||
is_depth = img_key in dataset.meta.depth_keys
|
||||
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
|
||||
# Save calibration frames
|
||||
for i, idx in enumerate(sample_indices):
|
||||
write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
|
||||
img = hf_dataset[idx][img_key]
|
||||
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
# Encode calibration video
|
||||
calibration_video_path = calibration_dir / "calibration.mp4"
|
||||
@@ -1348,7 +1331,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
video_encoder=video_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1621,7 +1604,6 @@ def recompute_stats(
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
# TODO: enable image and video stats re-computation
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
@@ -1668,7 +1650,6 @@ def convert_image_to_video_dataset(
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1680,32 +1661,21 @@ def convert_image_to_video_dataset(
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images.
|
||||
output_dir: Root directory where the converted dataset will be stored. When
|
||||
``None``, defaults to ``$HF_LEROBOT_HOME/repo_id``. Equivalent to
|
||||
``new_root`` in ``EditDatasetConfig``.
|
||||
repo_id: Converted dataset identifier. Equivalent to ``new_repo_id`` in
|
||||
``EditDatasetConfig``.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When ``None``,
|
||||
:func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth-map cameras, including
|
||||
the quantization parameters persisted to the dataset metadata. When
|
||||
``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
episode_indices: Episode indices to convert. When ``None``, all episodes are
|
||||
converted.
|
||||
num_workers: Number of threads for parallel processing.
|
||||
max_episodes_per_batch: Maximum episodes per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
max_frames_per_batch: Maximum frames per video batch, to bound memory use.
|
||||
``None`` means no limit.
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit)
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` with images encoded as videos.
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
if depth_encoder is None:
|
||||
depth_encoder = depth_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
@@ -1730,7 +1700,10 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1792,8 +1765,6 @@ def convert_image_to_video_dataset(
|
||||
episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices}
|
||||
|
||||
for img_key in tqdm(img_keys, desc="Processing cameras"):
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
|
||||
# Estimate size per frame by encoding a small calibration sample
|
||||
# This provides accurate compression ratio for the specific codec parameters
|
||||
size_per_frame_mb = _estimate_frame_size_via_calibration(
|
||||
@@ -1802,7 +1773,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
video_encoder=target_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1844,7 +1815,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
video_encoder=target_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1883,11 +1854,16 @@ def convert_image_to_video_dataset(
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos). They are registered as
|
||||
# video features above, so update_video_info populates their (still-empty) info.
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
target_encoder = depth_encoder if img_key in dataset.meta.depth_keys else camera_encoder
|
||||
new_meta.update_video_info(video_key=img_key, video_encoder=target_encoder)
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1914,11 +1890,11 @@ def convert_image_to_video_dataset(
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, video_encoder, encoder_threads = args
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
video_encoder=video_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -1927,8 +1903,7 @@ def _reencode_video_worker(args: tuple) -> Path:
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
@@ -1939,11 +1914,8 @@ def reencode_dataset(
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every RGB video
|
||||
file. If ``None``, re-encoding is skipped for RGB videos.
|
||||
depth_encoder: Target encoder configuration applied to every depth video
|
||||
file. If ``None``, re-encoding is skipped for depth videos.
|
||||
Quantization parameters will not override the ones in the current dataset.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
@@ -1955,35 +1927,23 @@ def reencode_dataset(
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_keys_encoders_dict = {}
|
||||
video_keys_paths_dict = {}
|
||||
|
||||
if camera_encoder is None and depth_encoder is None:
|
||||
raise ValueError("Either camera_encoder or depth_encoder must be provided")
|
||||
video_paths_list = []
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = encoder_config_from_video_info(current_info)
|
||||
target_encoder = depth_encoder if video_key in meta.depth_keys else camera_encoder
|
||||
if target_encoder is None:
|
||||
logging.info(f"No encoder provided for {video_key} video. Skipping re-encoding.")
|
||||
elif current_encoder != target_encoder:
|
||||
video_keys_paths_dict[video_key] = list((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
video_keys_encoders_dict[video_key] = target_encoder
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {target_encoder}. Nothing to do.")
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_keys_paths_dict) == 0:
|
||||
if len(video_paths_list) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
|
||||
worker_args = [
|
||||
(path, encoder, encoder_threads)
|
||||
for video_key, encoder in video_keys_encoders_dict.items()
|
||||
for path in video_keys_paths_dict[video_key]
|
||||
]
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
@@ -1997,15 +1957,10 @@ def reencode_dataset(
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every re-encoded key. Re-encoding only
|
||||
# changes codec/container params, so for depth videos we preserve ``is_depth_map``
|
||||
# and the depth quantization params (``video.depth_min`` / ``video.depth_max`` /
|
||||
# ...), which describe the data rather than the codec and must survive a transcode.
|
||||
# RGB videos pass an empty set: still a refresh, but nothing to preserve.
|
||||
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
|
||||
for video_key, encoder in video_keys_encoders_dict.items():
|
||||
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else set()
|
||||
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
@@ -31,12 +31,7 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
@@ -53,7 +48,6 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -73,22 +67,17 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
path_template = (
|
||||
DEFAULT_DEPTH_PATH
|
||||
if video_encoder is not None and isinstance(video_encoder, DepthEncoderConfig)
|
||||
else DEFAULT_IMAGE_PATH
|
||||
)
|
||||
fpath = path_template.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
video_encoder=video_encoder,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
@@ -108,7 +97,6 @@ class DatasetWriter:
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
@@ -120,11 +108,8 @@ class DatasetWriter:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
camera_encoder: Video encoder settings applied to RGB cameras. When
|
||||
``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to depth cameras, including
|
||||
the quantization parameters. When ``None``,
|
||||
:func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
@@ -136,7 +121,6 @@ class DatasetWriter:
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -161,8 +145,7 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -212,7 +195,6 @@ class DatasetWriter:
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
depth_video_keys=list(self._meta.depth_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
@@ -300,13 +282,10 @@ class DatasetWriter:
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v
|
||||
if k == "count"
|
||||
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
@@ -321,9 +300,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -534,12 +511,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(
|
||||
video_key,
|
||||
video_encoder=self._depth_encoder
|
||||
if video_key in self._meta.depth_keys
|
||||
else self._camera_encoder,
|
||||
)
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -606,14 +578,13 @@ class DatasetWriter:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png/tiff into mp4 videos."""
|
||||
is_depth = video_key in self._meta.depth_keys
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._depth_encoder if is_depth else self._camera_encoder,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,256 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.configs.video import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_PIX_FMT,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
DEPTH_QMAX,
|
||||
)
|
||||
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.integer] | NDArray[np.floating],
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
|
||||
resolved_unit = (
|
||||
("m" if np.issubdtype(depth.dtype, np.floating) else "mm") if input_unit == "auto" else input_unit
|
||||
)
|
||||
return depth.astype(np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
video_backend: str | None = "pyav",
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16] | av.VideoFrame:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
video_backend: Video backend to use for encoding. Defaults to "pyav".
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
if isinstance(depth, torch.Tensor):
|
||||
depth = depth.detach().cpu().numpy()
|
||||
|
||||
# Squeeze single-channel dim: (H, W, 1) or (1, H, W) → (H, W)
|
||||
if depth.ndim == 3 and (depth.shape[-1] == 1 or depth.shape[0] == 1):
|
||||
depth = depth.squeeze()
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
quantized = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX).astype(np.uint16, copy=False)
|
||||
|
||||
if video_backend == "pyav":
|
||||
frame = av.VideoFrame.from_ndarray(quantized, format=pix_fmt)
|
||||
write_u16_plane(frame.planes[0], quantized)
|
||||
return frame
|
||||
else:
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | av.VideoFrame | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
pix_fmt: str = DEFAULT_DEPTH_PIX_FMT,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
output_tensor: bool = True,
|
||||
output_channel_last: bool = False,
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32] | torch.Tensor:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit. Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Accepted input layouts :
|
||||
|
||||
- ``(H, W, 1)`` or ``(H, W)`` — single frame with channel-last.
|
||||
- ``(..., 1, H, W)`` — batched frames with channel-first.
|
||||
- ``(..., H, W, 1)`` — batched frames with channel-last.
|
||||
Output layout is determined by ``output_channel_last``.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes in ``[0, DEPTH_QMAX]``. ``np.ndarray``,
|
||||
``av.VideoFrame``, or ``torch.Tensor`` (any integer or float dtype).
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
pix_fmt: Pixel format used to extract the plane from an ``av.VideoFrame``.
|
||||
output_unit: ``"mm"`` returns ``uint16`` millimetres (rint, clip
|
||||
``[0, 65535]``) when returning a numpy array, or ``float32`` mm when
|
||||
``output_tensor=True``. ``"m"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
output_tensor: If True, return a ``torch.Tensor`` instead of a numpy array.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``output_unit`` is not ``"m"`` or ``"mm"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
|
||||
if isinstance(quantized, av.VideoFrame):
|
||||
quantized = quantized.to_ndarray(format=pix_fmt)
|
||||
|
||||
# Compute the scale and offset first.
|
||||
depth_min_m = float(depth_min)
|
||||
depth_max_m = float(depth_max)
|
||||
shift_m = float(shift)
|
||||
if use_log:
|
||||
log_min = math.log(depth_min_m + shift_m)
|
||||
log_max = math.log(depth_max_m + shift_m)
|
||||
scale = (log_max - log_min) / DEPTH_QMAX
|
||||
offset = log_min
|
||||
else:
|
||||
scale = (depth_max_m - depth_min_m) / DEPTH_QMAX
|
||||
offset = depth_min_m
|
||||
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
# Single allocation we own; everything else is in-place.
|
||||
buf = quantized.to(dtype=torch.float32, copy=True)
|
||||
buf.mul_(scale).add_(offset)
|
||||
if use_log:
|
||||
buf.exp_().sub_(shift_m)
|
||||
buf.clamp_(depth_min_m, depth_max_m)
|
||||
buf.unsqueeze_(-1) if output_channel_last else buf.unsqueeze_(-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return buf if output_tensor else buf.cpu().numpy()
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
|
||||
# ── NumPy path: single fp32 allocation, ``out=`` for in-place math. ─────
|
||||
arr = np.asarray(quantized)
|
||||
if arr.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
arr = np.squeeze(arr, axis=-3) if arr.shape[-3] == 1 else np.squeeze(arr, axis=-1)
|
||||
|
||||
buf = np.empty(arr.shape, dtype=np.float32)
|
||||
np.multiply(arr, scale, out=buf)
|
||||
np.add(buf, offset, out=buf)
|
||||
if use_log:
|
||||
np.exp(buf, out=buf)
|
||||
np.subtract(buf, shift_m, out=buf)
|
||||
np.clip(buf, depth_min_m, depth_max_m, out=buf)
|
||||
buf = np.expand_dims(buf, axis=-1) if output_channel_last else np.expand_dims(buf, axis=-3)
|
||||
|
||||
if output_unit == "m":
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
# torch.uint16 support is very limited; return float32 millimetres.
|
||||
return torch.from_numpy(buf)
|
||||
return buf.astype(np.uint16, copy=False)
|
||||
@@ -96,7 +96,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
return_uint8=True,
|
||||
depth_output_unit=cfg.dataset.depth_output_unit,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -336,7 +336,7 @@ def validate_feature_image_or_video(
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape, e.g. (C, H, W) or (H, W, C).
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -42,41 +42,10 @@ def safe_stop_image_writer(func):
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
# TODO(CarolinePascal): 4 dimensions RGB-D images
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image.")
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in [np.uint16, np.float32]:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
|
||||
)
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
@@ -102,28 +71,13 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
the save operation.
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -147,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import DepthEncoderConfig, VideoEncoderConfig
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -58,10 +58,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -188,9 +186,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
depth_encoder (DepthEncoderConfig | None, optional): Video encoder settings for depth cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
@@ -213,7 +208,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
@@ -254,7 +248,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
|
||||
# Load actual data
|
||||
@@ -280,7 +273,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
depth_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
@@ -288,7 +280,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -324,7 +315,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=self.delta_timestamps,
|
||||
image_transforms=self.image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
return self.reader
|
||||
|
||||
@@ -332,14 +322,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
depth_encoder: DepthEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -536,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool | None = None,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
@@ -555,8 +543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tag_version: If ``True``, create a Git tag for the current codebase
|
||||
version.
|
||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||
private: If ``True``, create a private repository. If ``None``
|
||||
(default), defer to the org default on the Hub (only affects orgs).
|
||||
private: If ``True``, create a private repository.
|
||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||
of ``upload_folder`` for very large datasets.
|
||||
@@ -658,7 +645,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -691,8 +677,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
@@ -727,7 +711,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
@@ -737,13 +720,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -767,7 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -797,8 +778,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch-encoding videos.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings for depth cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.depth_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
@@ -826,7 +805,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
|
||||
if obj._requested_root is not None:
|
||||
@@ -846,13 +824,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder, depth_encoder, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder=camera_encoder,
|
||||
depth_encoder=depth_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
exclude_images: bool = False,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
||||
exclude_images: If True, image features are dropped entirely from the output.
|
||||
use_videos: If False, image features are excluded.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
|
||||
)
|
||||
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and exclude_images:
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
@@ -24,7 +24,6 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,22 +31,6 @@ FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_pix_fmt_channels(pix_fmt: str) -> int:
|
||||
"""Return the number of components (channels) for *pix_fmt*."""
|
||||
return len(av.VideoFormat(pix_fmt).components)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -109,7 +92,7 @@ def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Opti
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = float(value)
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
@@ -159,16 +142,6 @@ def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_pix_fmt_channels(pix_fmt: str, channels: int) -> None:
|
||||
"""Ensure *pix_fmt* can carry at least *channels* components."""
|
||||
pix_fmt_channels = get_pix_fmt_channels(pix_fmt)
|
||||
if pix_fmt_channels < channels:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} carries only {pix_fmt_channels} component(s) "
|
||||
f"but the source data has {channels} channel(s)."
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
@@ -183,18 +156,12 @@ def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, Any],
|
||||
channels: int | None = None,
|
||||
) -> None:
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
When given, additionally verify that *pix_fmt* carries as many components as the source data channels.
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
@@ -204,6 +171,4 @@ def check_video_encoder_parameters_pyav(
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
if channels is not None:
|
||||
_check_pix_fmt_channels(pix_fmt, channels)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
|
||||
+32
-122
@@ -14,36 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||
|
||||
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||
instead of materializing a Python list of every frame index.
|
||||
|
||||
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||
resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||
starts (e.g. on resume or in tests).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
if from_indices.shape != to_indices.shape:
|
||||
raise ValueError(
|
||||
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||
f"got {len(from_indices)} and {len(to_indices)}"
|
||||
)
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
used = np.ones(len(from_indices), dtype=bool)
|
||||
if episode_indices_to_use is not None:
|
||||
used = np.zeros(len(from_indices), dtype=bool)
|
||||
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||
|
||||
starts = from_indices + drop_n_first_frames
|
||||
lengths = to_indices - drop_n_last_frames - starts
|
||||
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
to_indices[episode_idx] - from_indices[episode_idx],
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
used &= lengths > 0
|
||||
if not used.any():
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self._starts = starts[used]
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||
# and reproduces identically on every rank without touching the global RNG.
|
||||
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||
return torch.Generator().manual_seed(epoch_seed)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(int(order[k]))
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._num_frames
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||
return {"epoch": epoch, "start_index": start_index}
|
||||
return len(self.indices)
|
||||
|
||||
@@ -87,14 +87,11 @@ DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
|
||||
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -39,16 +39,11 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +53,6 @@ def decode_video_frames(
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -70,35 +64,23 @@ def decode_video_frames(
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): For RGB videos, if True return raw uint8 frames without float32 normalization.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
is_depth (bool): Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames (RGB: float32 in [0,1] by default, or uint8 if return_uint8=True, Depth: uint12).
|
||||
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
|
||||
)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -109,7 +91,6 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: float,
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
is_depth: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
@@ -128,9 +109,8 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
Otherwise, return float32 in [0, 1] range.
|
||||
is_depth: Set to True if the video is a depth map (1 channel, uint12).
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
@@ -160,13 +140,9 @@ def decode_video_frames_pyav(
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
if is_depth:
|
||||
arr = frame.to_ndarray(format="gray12le") # (H, W) uint12
|
||||
loaded_frames.append(torch.from_numpy(arr).unsqueeze(0).contiguous())
|
||||
else:
|
||||
arr = frame.to_ndarray(format="rgb24") # (H, W, 3)
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
@@ -209,7 +185,7 @@ def decode_video_frames_pyav(
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
|
||||
if return_uint8 or is_depth:
|
||||
if return_uint8:
|
||||
return closest_frames
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
@@ -430,38 +406,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Encode a directory of image frames into an MP4 video.
|
||||
|
||||
When ``video_encoder`` is a :class:`~lerobot.configs.video.DepthEncoderConfig`,
|
||||
frames are read from ``.tiff`` files and quantized to 12-bit depth codes using the
|
||||
encoder's ``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``; otherwise ``.png``
|
||||
RGB frames are encoded directly.
|
||||
|
||||
Args:
|
||||
imgs_dir: Directory containing the frames to encode, named ``frame-000000``
|
||||
onwards (``.png`` for RGB, ``.tiff`` for depth).
|
||||
video_path: Output path for the encoded ``.mp4`` file.
|
||||
fps: Frame rate of the output video.
|
||||
video_encoder: Encoder settings (codec, pixel format, quality, ...). When
|
||||
``None``, :func:`camera_encoder_defaults` is used. Pass a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` to encode depth frames.
|
||||
encoder_threads: Per-encoder thread count forwarded to the codec. ``None``
|
||||
lets the codec decide.
|
||||
log_level: libav log level to set while encoding, or ``None`` to leave the
|
||||
current logging configuration unchanged.
|
||||
overwrite: When ``False`` and ``video_path`` already exists, skip encoding and
|
||||
log a warning. When ``True``, re-encode and replace the existing file.
|
||||
"""
|
||||
if video_encoder is None:
|
||||
video_encoder = camera_encoder_defaults()
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -473,19 +428,17 @@ def encode_video_frames(
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get input frames
|
||||
is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
suffix = ".png" if not is_depth else ".tiff"
|
||||
template = "frame-" + ("[0-9]" * 6) + suffix
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -502,19 +455,8 @@ def encode_video_frames(
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
with Image.open(input_data) as input_image:
|
||||
if is_depth:
|
||||
input_frame = quantize_depth(
|
||||
np.array(input_image),
|
||||
depth_min=video_encoder.depth_min,
|
||||
depth_max=video_encoder.depth_max,
|
||||
shift=video_encoder.shift,
|
||||
use_log=video_encoder.use_log,
|
||||
pix_fmt=video_encoder.pix_fmt,
|
||||
video_backend="pyav",
|
||||
)
|
||||
else:
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
input_image = input_image.convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
@@ -535,32 +477,23 @@ def encode_video_frames(
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
start_time_s: float | None = None,
|
||||
end_time_s: float | None = None,
|
||||
) -> None:
|
||||
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
video_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
start_time_s: When set, trim the output to start at this timestamp (seconds).
|
||||
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||
"""
|
||||
|
||||
video_encoder = video_encoder or camera_encoder_defaults()
|
||||
|
||||
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
|
||||
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
@@ -570,9 +503,9 @@ def reencode_video(
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = video_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = video_encoder.vcodec
|
||||
pix_fmt = video_encoder.pix_fmt
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
@@ -593,10 +526,6 @@ def reencode_video(
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
|
||||
if start_time_s is not None:
|
||||
src.seek(int(start_time_s * av.time_base), backward=True)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
@@ -610,14 +539,7 @@ def reencode_video(
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame_time_s = frame.time
|
||||
if start_time_s is not None and frame_time_s < start_time_s:
|
||||
continue
|
||||
if end_time_s is not None and frame_time_s >= end_time_s:
|
||||
break
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
if start_time_s is not None:
|
||||
frame.pts = None # reset timestamps so the trimmed output starts at t=0
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
@@ -754,21 +676,22 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
video_encoder: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.video_encoder = video_encoder
|
||||
self.is_depth = isinstance(video_encoder, DepthEncoderConfig)
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -793,12 +716,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] in (1, 3):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if not self.is_depth and frame_data.dtype != np.uint8:
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -806,29 +729,15 @@ class _CameraEncoderThread(threading.Thread):
|
||||
height, width = frame_data.shape[:2]
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(
|
||||
self.video_encoder.vcodec,
|
||||
self.fps,
|
||||
options=self.video_encoder.get_codec_options(self.encoder_threads, as_strings=True),
|
||||
)
|
||||
output_stream.pix_fmt = self.video_encoder.pix_fmt
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
if not self.is_depth:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
else:
|
||||
video_frame = quantize_depth(
|
||||
frame_data,
|
||||
depth_min=self.video_encoder.depth_min,
|
||||
depth_max=self.video_encoder.depth_max,
|
||||
shift=self.video_encoder.shift,
|
||||
use_log=self.video_encoder.use_log,
|
||||
video_backend=self.video_encoder.video_backend,
|
||||
)
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
@@ -887,26 +796,21 @@ class StreamingVideoEncoder:
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all RGB cameras.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
depth_encoder: Video encoder settings applied to all depth cameras,
|
||||
including the depth quantization parameters. When ``None``,
|
||||
:func:`depth_encoder_defaults` is used.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._depth_encoder = depth_encoder or depth_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
|
||||
@@ -919,25 +823,18 @@ class StreamingVideoEncoder:
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(
|
||||
self, video_keys: list[str], temp_dir: Path, depth_video_keys: list[str] | None = None
|
||||
) -> None:
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
depth_video_keys: List of video or image feature keys that carry depth maps (e.g.
|
||||
["observation.images.laptop_depth"]). Defaults to ``[]`` (no depth keys).
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
if depth_video_keys is None:
|
||||
depth_video_keys = []
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
@@ -946,15 +843,17 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder = self._depth_encoder if video_key in depth_video_keys else self._camera_encoder
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
video_encoder=encoder,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self._encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1161,23 +1060,15 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
video_encoder: VideoEncoderConfig | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
video_encoder: If provided, record the exact encoder settings used to encode this
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself. When a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig` is passed, the depth
|
||||
quantization parameters (``depth_min`` / ``depth_max`` / ``shift`` /
|
||||
``use_log``) are recorded so frames can be dequantized on read.
|
||||
|
||||
Returns:
|
||||
The ``video.*`` / ``audio.*`` info dict, including ``is_depth_map`` which is
|
||||
``True`` only when ``video_encoder`` is a
|
||||
:class:`~lerobot.configs.video.DepthEncoderConfig`.
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
@@ -1195,10 +1086,13 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
video_info["video.channels"] = get_pix_fmt_channels(video_stream.pix_fmt)
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
@@ -1207,18 +1101,27 @@ def get_video_info(
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if video_encoder is not None:
|
||||
for field_name, field_value in asdict(video_encoder).items():
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
@@ -1279,13 +1182,10 @@ class VideoEncodingManager:
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
tiff_files = list(img_dir.rglob("*.tiff"))
|
||||
if len(png_files) == 0 and len(tiff_files) == 0:
|
||||
if len(png_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Images directory is not empty, containing {len(png_files)} PNG and {len(tiff_files)} TIFF files"
|
||||
)
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -57,7 +57,6 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -158,10 +157,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -216,8 +211,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -422,7 +415,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -440,14 +432,6 @@ def make_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -29,7 +29,6 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.__version__ import __version__
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
@@ -39,67 +38,6 @@ from .utils import log_model_loading_keys
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
|
||||
def _build_card_context(
|
||||
cfg: TrainPipelineConfig | None,
|
||||
dataset_repo_id: str | None,
|
||||
input_features: dict | None,
|
||||
output_features: dict | None,
|
||||
) -> dict:
|
||||
"""Collect optional data for the model-card template.
|
||||
|
||||
Returns plain values only (no Markdown) — the template in
|
||||
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
|
||||
each one. Everything is best-effort: anything unavailable is left empty/None and the
|
||||
template simply skips that section, so this never breaks a Hub push.
|
||||
"""
|
||||
context = {
|
||||
"training": None,
|
||||
"input_features": input_features or {},
|
||||
"output_features": output_features or {},
|
||||
"dataset": None,
|
||||
"robot_type": None,
|
||||
"cameras": [],
|
||||
}
|
||||
|
||||
if cfg is not None:
|
||||
optimizer = getattr(cfg, "optimizer", None)
|
||||
context["training"] = {
|
||||
"steps": cfg.steps,
|
||||
"batch_size": cfg.batch_size,
|
||||
"seed": cfg.seed,
|
||||
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
|
||||
"lr": getattr(optimizer, "lr", None) if optimizer else None,
|
||||
"lerobot_version": __version__,
|
||||
}
|
||||
|
||||
if dataset_repo_id:
|
||||
dataset_cfg = getattr(cfg, "dataset", None)
|
||||
try:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(
|
||||
dataset_repo_id,
|
||||
root=getattr(dataset_cfg, "root", None),
|
||||
revision=getattr(dataset_cfg, "revision", None),
|
||||
)
|
||||
context["dataset"] = {
|
||||
"repo_id": dataset_repo_id,
|
||||
"episodes": meta.total_episodes,
|
||||
"frames": meta.total_frames,
|
||||
"fps": meta.fps,
|
||||
"tasks": [str(task) for task in meta.tasks.index],
|
||||
}
|
||||
context["robot_type"] = meta.robot_type
|
||||
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
|
||||
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
|
||||
logging.warning(
|
||||
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
|
||||
f"omitted from the model card. ({e})"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
noise: Tensor | None
|
||||
|
||||
@@ -290,7 +228,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
@@ -308,20 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self,
|
||||
dataset_repo_id: str,
|
||||
model_type: str,
|
||||
license: str | None,
|
||||
tags: list[str] | None,
|
||||
cfg: TrainPipelineConfig | None = None,
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
"pi0": "lerobot/pi0_base",
|
||||
"pi05": "lerobot/pi05_base",
|
||||
"pi0_fast": "lerobot/pi0fast-base",
|
||||
"xvla": "lerobot/xvla-base",
|
||||
}
|
||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
||||
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
@@ -330,20 +257,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
base_model=base_model_mapping.get(model_type),
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
context = _build_card_context(
|
||||
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
|
||||
)
|
||||
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
|
||||
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
|
||||
context["base_model"] = base_model_mapping.get(model_type)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card, **context)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
|
||||
@@ -126,8 +126,7 @@ def prepare_observation_for_inference(
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
if observation[name].dtype == torch.uint8:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_vla_jepa_README.md
|
||||
@@ -1,23 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
@@ -1,337 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
gripper_dim: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,629 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
@@ -1,155 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes a gripper dimension after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
|
||||
Only applied when action has more dimensions than gripper_dim.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(
|
||||
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(
|
||||
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
@@ -1,418 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -32,6 +32,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
step_entry: dict[str, Any] = {}
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
return pipeline_config
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
config["steps"].append(step_entry)
|
||||
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
pipeline = cls(
|
||||
return cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
# 3. Load step state if available
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
|
||||
steps.append(step_instance)
|
||||
|
||||
return steps, override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("relative_actions_processor")
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
@@ -20,14 +20,12 @@ from .factory import (
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
|
||||
@@ -25,7 +25,6 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
@@ -39,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm", "robometer", "topreward".
|
||||
"sarm", "topreward".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -55,10 +54,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "robometer":
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
return RobometerRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
@@ -79,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm", "robometer", "topreward".
|
||||
"reward_classifier", "sarm", "topreward".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -92,8 +87,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "robometer":
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
@@ -175,13 +168,6 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(reward_cfg, RobometerConfig):
|
||||
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
return make_robometer_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_robometer import RobometerConfig
|
||||
from .modeling_robometer import RobometerRewardModel
|
||||
from .processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
|
||||
@@ -1,320 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
|
||||
|
||||
For each episode, builds per-frame sub-samples using the frame-steps
|
||||
strategy from the Robometer eval server: for each original frame ``t``,
|
||||
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
|
||||
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
|
||||
the Robometer processor + model, and keep the last-frame progress value.
|
||||
All sub-samples are the same size ``K`` so they batch cleanly.
|
||||
|
||||
The parquet uses the same schema as SARM's
|
||||
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers —
|
||||
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
|
||||
``progress_sparse``) and the progress-overlay script in
|
||||
``examples/dataset/create_progress_videos.py`` — work without modification.
|
||||
|
||||
Usage:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes with batching
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
|
||||
|
||||
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
|
||||
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
|
||||
"""Frame-steps linspace expansion.
|
||||
|
||||
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
|
||||
indices from ``np.linspace(0, t, num_subsampled_frames)`` — the first
|
||||
and last frames are always included. Each entry is a fixed-size array
|
||||
so the model can batch them.
|
||||
"""
|
||||
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
|
||||
|
||||
|
||||
def compute_robometer_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 32,
|
||||
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
episodes: list[int] | None = None,
|
||||
image_key: str | None = None,
|
||||
) -> Path:
|
||||
"""Run Robometer over a dataset and write per-frame progress + success."""
|
||||
logging.info(f"Loading Robometer: {reward_model_path}")
|
||||
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
|
||||
if image_key is not None:
|
||||
config.image_key = image_key
|
||||
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=num_subsampled_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
|
||||
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
|
||||
|
||||
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
|
||||
end = min(start + batch_size, num_frames)
|
||||
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames_batch},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value
|
||||
for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
rewards = model.compute_reward(batch)
|
||||
progress_per_frame[start:end] = rewards.cpu().numpy()
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(progress_per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes, smaller batches for memory-constrained GPUs
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
|
||||
)
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-subsampled-frames",
|
||||
type=int,
|
||||
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
reward_model_path = args.reward_model_path
|
||||
if reward_model_path is None:
|
||||
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
|
||||
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
|
||||
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
|
||||
if reward_model_path:
|
||||
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"--reward-model-path is required (no existing parquet with model metadata found)."
|
||||
)
|
||||
|
||||
output_path = compute_robometer_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=reward_model_path,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
num_subsampled_frames=args.num_subsampled_frames,
|
||||
episodes=args.episodes,
|
||||
image_key=args.image_key,
|
||||
)
|
||||
|
||||
print(f"\nRobometer progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,158 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
else:
|
||||
AutoConfig = None # type: ignore[assignment]
|
||||
AutoTokenizer = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
|
||||
# The order is part of the data contract: upstream resized ``embed_tokens``
|
||||
# after adding these tokens in this exact order, so changing the set or order
|
||||
# would silently misalign the saved embedding rows with their token ids.
|
||||
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
|
||||
# heads (never read at inference) but still occupy rows the checkpoint expects.
|
||||
ROBOMETER_SPECIAL_TOKENS = (
|
||||
"<|split_token|>",
|
||||
"<|reward_token|>",
|
||||
"<|pref_token|>",
|
||||
"<|sim_token|>",
|
||||
"<|prog_token|>",
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("robometer")
|
||||
@dataclass
|
||||
class RobometerConfig(RewardModelConfig):
|
||||
"""Configuration for the Robometer reward model."""
|
||||
|
||||
pretrained_path: str | None = "lerobot/Robometer-4B"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
|
||||
max_frames: int | None = 8
|
||||
reward_output: str = "progress" # "progress" or "success"
|
||||
success_threshold: float = 0.5
|
||||
|
||||
license: str | None = "apache-2.0"
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
torch_dtype: str = "bfloat16"
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
average_temporal_patches: bool = True
|
||||
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
|
||||
frame_pooling_attn_temperature: float = 1.0
|
||||
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
|
||||
progress_discrete_bins: int = 10
|
||||
|
||||
# Serialised Qwen backbone config (post-resize). Always populated by
|
||||
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
|
||||
# is non-empty after construction. Saved into ``config.json`` automatically
|
||||
# by the base ``_save_pretrained``.
|
||||
vlm_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.reward_output not in {"progress", "success"}:
|
||||
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.frame_pooling not in {"mean", "boundary", "attention"}:
|
||||
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
|
||||
if self.frame_pooling_attn_temperature <= 0:
|
||||
raise ValueError("frame_pooling_attn_temperature must be > 0")
|
||||
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
|
||||
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
|
||||
if self.use_per_frame_progress_token and not self.use_multi_image:
|
||||
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
# Deterministically populate ``vlm_config`` so it is non-empty after
|
||||
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
|
||||
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
|
||||
# vocab the published ``Robometer-4B`` checkpoint was saved with.
|
||||
if not self.vlm_config:
|
||||
require_package("transformers", extra="robometer")
|
||||
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
|
||||
text_config = vlm.get("text_config")
|
||||
if not isinstance(text_config, dict):
|
||||
raise ValueError(
|
||||
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
|
||||
"Robometer expects a Qwen-VL-style config."
|
||||
)
|
||||
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
|
||||
self.vlm_config = vlm
|
||||
|
||||
@property
|
||||
def use_discrete_progress(self) -> bool:
|
||||
"""Whether the progress head outputs distribution logits over bins."""
|
||||
return self.progress_loss_type.lower() == "discrete"
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self):
|
||||
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
|
||||
require_package("transformers", extra="robometer")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
model_type = config_dict.pop("model_type", None)
|
||||
if model_type is None:
|
||||
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
|
||||
return AutoConfig.for_model(model_type, **config_dict)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
|
||||
@@ -1,481 +0,0 @@
|
||||
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
|
||||
|
||||
Paper: https://arxiv.org/abs/2603.02115
|
||||
Project: https://robometer.github.io
|
||||
Original code: https://github.com/aliang8/robometer
|
||||
Model: https://huggingface.co/robometer/Robometer-4B
|
||||
|
||||
Robometer is a general-purpose, video-language-input reward model built on
|
||||
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
|
||||
objective:
|
||||
|
||||
- A frame-level progress loss anchoring reward magnitude on expert data.
|
||||
- A trajectory-comparison preference loss imposing global ordering constraints
|
||||
across trajectories sharing the same instruction.
|
||||
|
||||
To support downstream RL it also predicts a frame-level binary success. The
|
||||
training prompt inserts three learnable tokens:
|
||||
|
||||
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
|
||||
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
|
||||
- ``<|split_token|>`` between two trajectories in preference samples
|
||||
(training-only).
|
||||
|
||||
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
|
||||
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
|
||||
is recovered as the softmax-weighted mean of those centers — see
|
||||
:func:`convert_bins_to_continuous`.
|
||||
|
||||
This LeRobot port is **inference-only**: the preference head is preserved in
|
||||
the state dict for byte-equivalence with the published ``Robometer-4B``
|
||||
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
|
||||
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
|
||||
success probability depending on :attr:`RobometerConfig.reward_output`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
else:
|
||||
AutoModelForImageTextToText = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
|
||||
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
|
||||
ROBOMETER_QWEN_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"pixel_values_videos",
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"second_per_grid_ts",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
ROBOMETER_METADATA_KEYS = (
|
||||
"prog_token_id",
|
||||
"vision_start_token_id",
|
||||
"vision_end_token_id",
|
||||
"video_merge_size",
|
||||
)
|
||||
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
|
||||
|
||||
|
||||
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
|
||||
"""Collapse per-bin logits into a single value in ``[0, 1]``.
|
||||
|
||||
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
|
||||
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
|
||||
softmax-weighted mean of those centers.
|
||||
"""
|
||||
bin_probs = torch.softmax(bin_logits, dim=-1)
|
||||
num_bins = bin_logits.shape[-1]
|
||||
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
|
||||
return (bin_probs * bin_centers).sum(dim=-1)
|
||||
|
||||
|
||||
def _squeeze_last_safe(x: Tensor) -> Tensor:
|
||||
"""Drop a trailing singleton dim only when present."""
|
||||
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype:
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class RobometerPredictionHead(nn.Sequential):
|
||||
"""Small MLP head used for Robometer's progress / success / preference outputs."""
|
||||
|
||||
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
|
||||
layers: list[nn.Module] = [
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, output_size),
|
||||
]
|
||||
if with_sigmoid:
|
||||
layers.append(nn.Sigmoid())
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
def decode_progress_outputs(
|
||||
progress_logits: Tensor | None,
|
||||
success_logits: Tensor | None,
|
||||
*,
|
||||
is_discrete_mode: bool,
|
||||
) -> dict[str, list[list[float]]]:
|
||||
"""Decode RBM head outputs into per-frame floats.
|
||||
|
||||
Args:
|
||||
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
|
||||
is_discrete_mode: if True the progress logits get a softmax over bins
|
||||
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
|
||||
|
||||
Returns:
|
||||
Dict with ``progress_pred`` and ``success_probs``, each a list of
|
||||
length ``B`` of per-frame float lists.
|
||||
"""
|
||||
progress_pred: list[list[float]] = []
|
||||
success_probs: list[list[float]] = []
|
||||
|
||||
if progress_logits is not None:
|
||||
for sample_logits in progress_logits:
|
||||
if is_discrete_mode:
|
||||
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
|
||||
progress_pred.append(continuous.flatten().tolist())
|
||||
else:
|
||||
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
|
||||
|
||||
if success_logits is not None:
|
||||
for sample_logits in success_logits:
|
||||
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
|
||||
|
||||
return {"progress_pred": progress_pred, "success_probs": success_probs}
|
||||
|
||||
|
||||
class RobometerRewardModel(PreTrainedRewardModel):
|
||||
"""Robometer (RBM) reward model — inference-only LeRobot port.
|
||||
|
||||
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
|
||||
prediction heads from the paper (progress, success, preference). At
|
||||
inference time only the progress and success heads are queried; the
|
||||
preference head is kept on the module so the published ``Robometer-4B``
|
||||
safetensors load unchanged.
|
||||
"""
|
||||
|
||||
name = "robometer"
|
||||
config_class = RobometerConfig
|
||||
|
||||
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
|
||||
#
|
||||
# - Fresh training (``pretrained_path is None``): download the base
|
||||
# Qwen weights and resize the embed table to match
|
||||
# ``vlm_config.text_config.vocab_size`` — populated deterministically
|
||||
# in ``RobometerConfig.__post_init__`` as
|
||||
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
|
||||
#
|
||||
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
|
||||
# the empty architecture from ``vlm_config`` via
|
||||
# ``AutoModelForImageTextToText.from_config`` so the subsequent
|
||||
# ``model.safetensors`` load is a direct fill of the right shape —
|
||||
# no redundant Qwen weight download.
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
if config.pretrained_path is None:
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||
config.base_model_id,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
target_vocab = config.vlm_config["text_config"]["vocab_size"]
|
||||
self.model.resize_token_embeddings(target_vocab)
|
||||
else:
|
||||
self.model = AutoModelForImageTextToText.from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
|
||||
# Falls back to the top-level `hidden_size` so future non-multimodal
|
||||
# variants would still resolve.
|
||||
backbone_config = self.model.config
|
||||
text_config = getattr(backbone_config, "text_config", None)
|
||||
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
|
||||
if hidden_size is None:
|
||||
hidden_size = getattr(backbone_config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise AttributeError(
|
||||
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
|
||||
)
|
||||
hidden_dim = int(hidden_size)
|
||||
|
||||
# Robometer's three prediction heads + frame-pool attention.
|
||||
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
|
||||
self.progress_head = RobometerPredictionHead(
|
||||
hidden_dim,
|
||||
progress_output,
|
||||
dropout=dropout,
|
||||
with_sigmoid=not config.use_discrete_progress,
|
||||
)
|
||||
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
|
||||
|
||||
# Match the dtype of the loaded base model so weight loading is a no-op cast.
|
||||
model_dtype = next(self.model.parameters()).dtype
|
||||
self.progress_head.to(dtype=model_dtype)
|
||||
self.preference_head.to(dtype=model_dtype)
|
||||
self.success_head.to(dtype=model_dtype)
|
||||
self.frame_pool_attn.to(dtype=model_dtype)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = {
|
||||
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
|
||||
for key in ROBOMETER_INPUT_KEYS
|
||||
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
|
||||
}
|
||||
if "input_ids" not in inputs:
|
||||
raise KeyError(
|
||||
f"Robometer batch missing pre-encoded inputs (expected "
|
||||
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
|
||||
"RobometerEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
progress_logits, success_logits = self._compute_rbm_logits(inputs)
|
||||
|
||||
decoded = decode_progress_outputs(
|
||||
progress_logits,
|
||||
success_logits,
|
||||
is_discrete_mode=self.config.use_discrete_progress,
|
||||
)
|
||||
values = (
|
||||
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
|
||||
)
|
||||
|
||||
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
|
||||
if self.config.reward_output == "success":
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
else:
|
||||
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
|
||||
# progress predictions are clamped to ``[0, 1]`` before being returned.
|
||||
rewards = rewards.clamp(0.0, 1.0)
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _compute_rbm_logits(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Run the Qwen3-VL backbone and apply Robometer's heads.
|
||||
|
||||
``inputs`` is the encoded batch produced by
|
||||
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
|
||||
as Robometer-specific metadata (``prog_token_id``,
|
||||
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
|
||||
— the metadata is popped here so the rest can be forwarded straight to
|
||||
the Qwen model.
|
||||
|
||||
Returns ``(progress_logits, success_logits)``. Shapes:
|
||||
|
||||
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
|
||||
"""
|
||||
prog_token_id = inputs.pop("prog_token_id", None)
|
||||
vision_start_token_id = inputs.pop("vision_start_token_id", None)
|
||||
vision_end_token_id = inputs.pop("vision_end_token_id", None)
|
||||
video_merge_size = inputs.pop("video_merge_size", 14)
|
||||
|
||||
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
|
||||
# full hidden-state tuple and take the last layer. This matches the
|
||||
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
|
||||
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
hidden_state = (
|
||||
outputs.hidden_states[-1]
|
||||
if getattr(outputs, "hidden_states", None)
|
||||
else outputs.last_hidden_state
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
if self.config.use_per_frame_progress_token:
|
||||
if prog_token_id is None:
|
||||
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
|
||||
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
|
||||
if self.config.use_multi_image:
|
||||
if vision_start_token_id is None or vision_end_token_id is None:
|
||||
raise KeyError(
|
||||
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
|
||||
"(run RobometerEncoderProcessorStep first)"
|
||||
)
|
||||
return self._process_multi_image_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
start_id=vision_start_token_id,
|
||||
end_id=vision_end_token_id,
|
||||
)
|
||||
video_grid_thw = inputs.get("video_grid_thw")
|
||||
if video_grid_thw is None:
|
||||
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
|
||||
if vision_start_token_id is None:
|
||||
raise KeyError("`vision_start_token_id` missing in batch")
|
||||
return self._process_video_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
video_grid_thw,
|
||||
start_id=vision_start_token_id,
|
||||
merge_size=video_merge_size,
|
||||
)
|
||||
|
||||
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Apply progress + success heads to a tensor of frame embeddings."""
|
||||
progress_out = self.progress_head(frame_embeddings)
|
||||
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
|
||||
success = _squeeze_last_safe(self.success_head(frame_embeddings))
|
||||
return progress, success
|
||||
|
||||
def _process_token_extraction(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
prog_token_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
|
||||
token_mask = input_ids == prog_token_id
|
||||
batch_indices, positions = token_mask.nonzero(as_tuple=True)
|
||||
if positions.numel() == 0:
|
||||
raise ValueError("`<|prog_token|>` not found in any sequence")
|
||||
|
||||
per_sample_hidden = [
|
||||
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
|
||||
]
|
||||
progress_list, success_list = [], []
|
||||
for embeddings in per_sample_hidden:
|
||||
if embeddings.shape[0] == 0:
|
||||
raise ValueError("`<|prog_token|>` missing in a sequence")
|
||||
progress, success = self._apply_heads_to_hidden_states(embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _process_multi_image_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
frame_embeddings = self._extract_hidden_states_from_token_pairs(
|
||||
seq_hidden, seq_ids, start_id, end_id
|
||||
)
|
||||
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _extract_hidden_states_from_token_pairs(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> Tensor:
|
||||
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
if start_positions.numel() != end_positions.numel():
|
||||
raise ValueError(
|
||||
f"Mismatched vision token counts: {start_positions.numel()} start vs "
|
||||
f"{end_positions.numel()} end"
|
||||
)
|
||||
|
||||
frames: list[Tensor] = []
|
||||
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
|
||||
if start >= end:
|
||||
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
|
||||
patch_tokens = hidden_state[start + 1 : end]
|
||||
if patch_tokens.shape[0] == 0:
|
||||
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
|
||||
continue
|
||||
|
||||
pooling = self.config.frame_pooling
|
||||
if pooling == "mean":
|
||||
frames.append(patch_tokens.mean(dim=0))
|
||||
elif pooling == "boundary":
|
||||
frames.append(patch_tokens[-1])
|
||||
else: # attention
|
||||
scores = (
|
||||
self.frame_pool_attn(patch_tokens).squeeze(-1)
|
||||
/ self.config.frame_pooling_attn_temperature
|
||||
)
|
||||
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
|
||||
frames.append((weights * patch_tokens).sum(dim=0))
|
||||
|
||||
return torch.stack(frames)
|
||||
|
||||
def _process_video_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
video_grid_thw: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
merge_size: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in video mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
|
||||
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
|
||||
|
||||
cursor = start_positions[0].item()
|
||||
frame_embeddings: list[Tensor] = []
|
||||
for _ in range(t_dim):
|
||||
if self.config.average_temporal_patches:
|
||||
patch = seq_hidden[cursor : cursor + tokens_per_frame]
|
||||
frame_embeddings.append(patch.mean(dim=0))
|
||||
else:
|
||||
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
|
||||
cursor += tokens_per_frame
|
||||
|
||||
stacked = torch.stack(frame_embeddings)
|
||||
progress, success = self._apply_heads_to_hidden_states(stacked)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
@@ -1,338 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Robometer pre/post processing pipelines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.robometer.configuration_robometer import (
|
||||
ROBOMETER_SPECIAL_TOKENS,
|
||||
RobometerConfig,
|
||||
)
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
PROGRESS_PROMPT = (
|
||||
"The task for the robot is '{task}'. Given the trajectory video, predict "
|
||||
"the task progress at each frame, how far along the robot is towards "
|
||||
"completing the task, a float between 0 and 1, where 0 is the starting "
|
||||
"state and 1 is when the task is completed. If the robot is not "
|
||||
"performing the same task, predict 0 progress."
|
||||
)
|
||||
|
||||
|
||||
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
|
||||
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
|
||||
if frames.ndim != 4:
|
||||
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
|
||||
if frames.dtype != np.uint8:
|
||||
frames = np.clip(frames, 0, 255).astype(np.uint8)
|
||||
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
|
||||
|
||||
|
||||
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
|
||||
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
|
||||
if max_frames is not None:
|
||||
video = video[-max_frames:]
|
||||
if video.shape[1] in (1, 3):
|
||||
video = video.permute(0, 2, 3, 1)
|
||||
elif video.shape[-1] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
array = video.detach().cpu().numpy()
|
||||
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
|
||||
array = array * 255.0
|
||||
return np.clip(array, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("Robometer expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robometer_encoder")
|
||||
class RobometerEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
|
||||
registers Robometer's special tokens on the tokenizer. The matching
|
||||
embedding resize happens model-side in
|
||||
:meth:`RobometerRewardModel.__init__`.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
|
||||
|
||||
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
|
||||
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
|
||||
- Robometer-specific token ids consumed by the model heads:
|
||||
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
|
||||
``video_merge_size``.
|
||||
"""
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 8
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
max_length: int = 1024
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
|
||||
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
self.base_model_id,
|
||||
trust_remote_code=True,
|
||||
do_sample_frames=False,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
# Register Robometer's special tokens on the tokenizer. The matching
|
||||
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
|
||||
tokenizer = self._processor.tokenizer
|
||||
# Qwen tokenizers may not define a pad token, but batched prompts/videos
|
||||
# require padding, so reuse EOS as the padding token.
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
for token in ROBOMETER_SPECIAL_TOKENS:
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
|
||||
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
if tensor.ndim == 4:
|
||||
tensor = tensor.unsqueeze(1)
|
||||
elif tensor.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
batch_size = tensor.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
samples = [
|
||||
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
|
||||
]
|
||||
encoded = self.encode_samples(samples)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
|
||||
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
conversations = [self._build_conversation(frames, task) for frames, task in samples]
|
||||
|
||||
texts = [
|
||||
self._processor.apply_chat_template(
|
||||
msg,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=True,
|
||||
enable_thinking=False,
|
||||
fps=1,
|
||||
)
|
||||
for msg in conversations
|
||||
]
|
||||
|
||||
process_kwargs: dict[str, Any] = {
|
||||
"return_video_kwargs": True,
|
||||
"return_video_metadata": True,
|
||||
}
|
||||
image_processor = getattr(self._processor, "image_processor", None)
|
||||
if image_processor is not None and hasattr(image_processor, "patch_size"):
|
||||
process_kwargs["image_patch_size"] = image_processor.patch_size
|
||||
|
||||
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
|
||||
|
||||
videos: list[Any] | None = None
|
||||
video_metadatas: list[Any] | None = None
|
||||
if video_inputs:
|
||||
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
|
||||
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
|
||||
videos = list(videos_seq)
|
||||
video_metadatas = list(metadatas_seq)
|
||||
else:
|
||||
videos = list(video_inputs)
|
||||
|
||||
processor_kwargs: dict[str, Any] = {
|
||||
"text": texts,
|
||||
"images": image_inputs,
|
||||
"padding": True,
|
||||
"truncation": False,
|
||||
"max_length": self.max_length,
|
||||
"return_tensors": "pt",
|
||||
"do_resize": False,
|
||||
}
|
||||
if videos is not None:
|
||||
processor_kwargs["videos"] = videos
|
||||
if video_metadatas is not None:
|
||||
processor_kwargs["video_metadata"] = video_metadatas
|
||||
if video_kwargs:
|
||||
processor_kwargs.update(video_kwargs)
|
||||
|
||||
encoded = self._processor(**processor_kwargs)
|
||||
|
||||
# Write Robometer-specific token ids and the video patch merge size into
|
||||
# the encoded batch so `RobometerRewardModel` doesn't need its own
|
||||
# tokenizer at inference (EO1-style separation: the processor owns the
|
||||
# tokenizer, the model owns the backbone and heads).
|
||||
tokenizer = self._processor.tokenizer
|
||||
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
|
||||
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
||||
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
|
||||
video_processor = getattr(self._processor, "video_processor", None)
|
||||
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
|
||||
return encoded
|
||||
|
||||
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
|
||||
pil_frames = _frames_to_pil(frames)
|
||||
prompt = PROGRESS_PROMPT.format(task=task)
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
||||
|
||||
if self.use_multi_image:
|
||||
for image in pil_frames:
|
||||
content.append({"type": "image", "image": image})
|
||||
if self.use_per_frame_progress_token:
|
||||
content.append({"type": "text", "text": "<|prog_token|>"})
|
||||
else:
|
||||
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"base_model_id": self.base_model_id,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"use_multi_image": self.use_multi_image,
|
||||
"use_per_frame_progress_token": self.use_per_frame_progress_token,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_robometer_pre_post_processors(
|
||||
config: RobometerConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs Robometer's
|
||||
encoder, and moves everything to the configured device. The
|
||||
postprocessor is the identity since Robometer outputs a single reward
|
||||
tensor.
|
||||
"""
|
||||
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -18,8 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from ..robot import Robot
|
||||
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
class BiOpenArmFollower(Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
||||
# will only open the cameras assigned to it. Per-arm cameras are used
|
||||
# as fallback when top-level cameras are empty.
|
||||
if config.cameras:
|
||||
left_cameras = config.cameras
|
||||
right_cameras = {}
|
||||
else:
|
||||
left_cameras = config.left_arm_config.cameras
|
||||
right_cameras = config.right_arm_config.cameras
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_arm_cameras,
|
||||
cameras=left_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
cameras=right_cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
||||
# "right_wrist"), so we merge them directly — unlike motors which need the
|
||||
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
||||
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
||||
left_cam_keys = set(self.left_arm.cameras.keys())
|
||||
right_cam_keys = set(self.right_arm.cameras.keys())
|
||||
|
||||
# Add "right_" prefix
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
right_obs = self.right_arm.get_observation()
|
||||
for key, value in right_obs.items():
|
||||
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
||||
|
||||
left_obs = self.left_arm.get_observation()
|
||||
for key, value in left_obs.items():
|
||||
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
# Top-level cameras shared across both arms.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,8 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
||||
from ..robot import Robot
|
||||
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
class BiRebotB601Follower(Robot):
|
||||
"""Bimanual Seeed Studio reBot B601-DM follower.
|
||||
|
||||
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
||||
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = RebotB601FollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_arm_cameras,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
for k, v in self.left_arm.get_observation().items():
|
||||
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{k}"] = v
|
||||
obs_dict = {}
|
||||
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
|
||||
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
||||
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..rebot_b601_follower import RebotB601FollowerConfig
|
||||
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: RebotB601FollowerConfig
|
||||
right_arm_config: RebotB601FollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,8 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
||||
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiSOFollower(BimanualMixin, Robot):
|
||||
class BiSOFollower(Robot):
|
||||
"""
|
||||
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = SOFollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
use_degrees=config.left_arm_config.use_degrees,
|
||||
cameras=left_arm_cameras,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
)
|
||||
|
||||
right_arm_config = SOFollowerRobotConfig(
|
||||
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict: RobotObservation = {}
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..so_follower import SOFollowerConfig
|
||||
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: SOFollowerConfig
|
||||
right_arm_config: SOFollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -68,12 +68,9 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||
if getattr(self.cameras[cam], "use_depth", False):
|
||||
features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width, 1)
|
||||
return features
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -193,12 +190,6 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
if getattr(cam, "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user