mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 11:17:02 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d99d4eda0f | |||
| 84961c5591 | |||
| 2e53ccf872 | |||
| 73782447f2 | |||
| 2d7a42011a | |||
| b06ad40888 | |||
| b3d74f80f0 | |||
| 552b4c3563 | |||
| 8bf6056d14 | |||
| da92db8fc0 | |||
| 2b0834bcb8 | |||
| 9449e68725 | |||
| 287c823f13 | |||
| 2b83956eb5 | |||
| 7309790d56 | |||
| 58ccc01508 | |||
| 38327fdc84 | |||
| 9555efc02c | |||
| d576c59afb | |||
| 2201401c99 | |||
| 64773e7b22 | |||
| 8515d456be | |||
| 30790de178 | |||
| cec8ee0be6 | |||
| 02b315ab6a | |||
| 234c768dfb | |||
| 0e9bd9e6fb | |||
| 87242cfced | |||
| 1edc83a0ef | |||
| 6fbcf67249 | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 | |||
| 49755a3d9e | |||
| 09808183ca | |||
| 2e9cd87bbd | |||
| d1b1c5c8cf | |||
| 741c2d0a39 | |||
| 19fe315971 | |||
| 906b585826 | |||
| b8ad81bf39 | |||
| 24017e960c | |||
| e86f5af5bf | |||
| 5c98e80430 | |||
| f65f3f7a4a |
@@ -167,9 +167,9 @@ jobs:
|
|||||||
|
|
||||||
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
||||||
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
||||||
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
|
# immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
|
||||||
# Tests the full train→eval-within-training pipeline end-to-end.
|
# Tests the full train→eval-within-training pipeline end-to-end.
|
||||||
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
|
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
|
||||||
if: env.HF_USER_TOKEN != ''
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
docker run --name libero-train-smoke --gpus all \
|
docker run --name libero-train-smoke --gpus all \
|
||||||
@@ -196,7 +196,7 @@ jobs:
|
|||||||
--output_dir=/tmp/train-smoke \
|
--output_dir=/tmp/train-smoke \
|
||||||
--steps=1 \
|
--steps=1 \
|
||||||
--batch_size=1 \
|
--batch_size=1 \
|
||||||
--eval_freq=1 \
|
--env_eval_freq=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ repos:
|
|||||||
name: Format Markdown with Prettier
|
name: Format Markdown with Prettier
|
||||||
types_or: [markdown, mdx]
|
types_or: [markdown, mdx]
|
||||||
args: [--prose-wrap=preserve]
|
args: [--prose-wrap=preserve]
|
||||||
|
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
|
||||||
|
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
|
||||||
|
exclude: ^src/lerobot/templates/.*\.md$
|
||||||
|
|
||||||
##### Security #####
|
##### Security #####
|
||||||
- repo: https://github.com/gitleaks/gitleaks
|
- repo: https://github.com/gitleaks/gitleaks
|
||||||
|
|||||||
@@ -1,404 +0,0 @@
|
|||||||
# LeRobot Documentation Redesign — Proposal
|
|
||||||
|
|
||||||
**Status:** Draft for maintainer review · **Author:** Dev Rel · **Date:** June 2026
|
|
||||||
**Scope:** Information architecture + UX redesign of https://huggingface.co/docs/lerobot (source: `docs/source/`)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Executive summary
|
|
||||||
|
|
||||||
**The problem, in five points:**
|
|
||||||
|
|
||||||
1. **No guided beginner path.** "Get started" is three pages (landing, install, CLI cheat-sheet). After that, a newcomer lands in a "Tutorials" section that mixes the core 652-line SO-101 walkthrough (`il_robots.mdx`) with RL (950-line `hilserl.mdx`), PEFT, multi-GPU, and rename-map content — four different audiences in one flat list.
|
|
||||||
2. **Zero decision guidance.** 11 robot pages, 12 policy pages, 10 benchmark pages — all flat alphabetical-ish lists. Nothing answers "which robot should I buy?", "which policy after ACT?", "which benchmark fits my claim?".
|
|
||||||
3. **The best beginner content isn't on the docs site.** `AGENT_GUIDE.md` (repo root) contains the most complete procedural guidance we have — data-collection golden rules, policy selection by VRAM tier, training-duration heuristics, eval success-rate targets — and none of it is published.
|
|
||||||
4. **Structural debt.** 14 orphan `policy_*_README.md` files sit in `docs/source/` unreferenced by the toctree (verbatim hand-copies of `src/lerobot/policies/*/README.md`); 6 of 17 CLI commands are completely undocumented; 1 of 15 teleoperators has a doc page; there is no troubleshooting page and no `_redirects.yml`.
|
|
||||||
5. **Pages do too many jobs.** Tutorial, how-to, and reference content is interleaved (e.g., `lerobot-dataset-v3.mdx` is simultaneously a format spec, a Python API guide, and a migration guide), which makes every page long and none of them skimmable.
|
|
||||||
|
|
||||||
**The proposal, in five points:**
|
|
||||||
|
|
||||||
1. **A numbered "Learn: your first robot" course** (FastAPI-style): 11 small CLI-only pages taking an SO-101 owner from install to a trained, evaluated ACT policy. Each page ends with a checkpoint and a "Next" link. A one-page no-hardware quickstart runs in parallel.
|
|
||||||
2. **A sidebar that reads top-to-bottom as increasing expertise**, organized into 17 sections: Get started → task-oriented how-tos (Collect data / Train / Deploy & evaluate) → hardware catalogs → policy catalogs → concepts → RL → extending → reference → community.
|
|
||||||
3. **Every catalog section opens with a comparison/decision page**: Choosing a robot, Teleoperator overview, Choosing a policy, Reward models overview, Choosing a benchmark.
|
|
||||||
4. **Publish the missing content**: AGENT_GUIDE.md's tips become real docs pages; a full CLI reference covers all 17 commands; a symptom-organized troubleshooting page aggregates fixes currently buried across five pages.
|
|
||||||
5. **Minimal URL breakage**: existing slugs are kept wherever possible. Only **2 redirects** are needed (`cheat-sheet → cli_reference`, `il_robots → learn/index`); everything else is a toctree reposition.
|
|
||||||
|
|
||||||
**Effort estimate:** ~4–5 person-weeks across 5 phases; phases 3 and 5 are parallelizable and community-friendly (see §9).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Current-state audit
|
|
||||||
|
|
||||||
### 2.1 By the numbers
|
|
||||||
|
|
||||||
| Metric | Value |
|
|
||||||
|---|---|
|
|
||||||
| Files in `docs/source/` (excl. `_toctree.yml`) | 90 |
|
|
||||||
| Pages in the toctree | 76 |
|
|
||||||
| Orphan files (on disk, not in toctree) | 14 (`policy_*_README.md`) |
|
|
||||||
| Top-level sidebar sections | 15, all flat (no nesting) |
|
|
||||||
| Longest pages | `hilserl.mdx` ~950 lines · `il_robots.mdx` ~650 · `so100.mdx` ~640 |
|
|
||||||
| CLI commands in `pyproject.toml [project.scripts]` | 17 — **6 with zero doc presence** (`lerobot-find-cameras`, `lerobot-setup-can`, `lerobot-find-joint-limits`, `lerobot-train-tokenizer`, `lerobot-imgtransform-viz`, `lerobot-info`) |
|
|
||||||
| Teleoperators in `src/lerobot/teleoperators/` | 15 — **1 documented** (`phone_teleop.mdx`); keyboard, gamepad, leader arms covered only implicitly inside robot/tutorial pages |
|
|
||||||
| Camera backends | 4 — `zmq` undocumented, `reachy2_camera` implicit |
|
|
||||||
| Policies in `src/lerobot/policies/` | 19 — `diffusion`, `tdmpc`, `vqbet` have orphan READMEs but **no doc page**; `sac`, `gaussian_actor` undocumented |
|
|
||||||
| Redirect infrastructure | None (`_redirects.yml` does not exist) |
|
|
||||||
|
|
||||||
### 2.2 Prioritized UX problems
|
|
||||||
|
|
||||||
**Critical (blocks the main funnel):**
|
|
||||||
|
|
||||||
| # | Problem | Evidence |
|
|
||||||
|---|---|---|
|
|
||||||
| C1 | No linear beginner path: install → ??? | "Get started" = `index.mdx` (brand copy) + `installation.mdx` + `cheat-sheet.mdx`. The actual path lives inside `il_robots.mdx`, filed under "Tutorials" with 9 unrelated pages. |
|
|
||||||
| C2 | No "which robot do I buy?" guidance | 11 robot pages, no comparison. SO-101 vs SO-100 relationship is never stated; SO-100's page is *longer* than SO-101's. |
|
|
||||||
| C3 | `il_robots.mdx` monolith | One 652-line page covers calibrate + teleoperate + cameras + record + train + evaluate, doubling every step in CLI *and* Python, all SO-101-specific despite the generic title. |
|
|
||||||
| C4 | Best procedural content unpublished | `AGENT_GUIDE.md` §5 (data tips), §6 (policy choice), §7 (training duration), §8 (eval targets) exist only in the repo root. |
|
|
||||||
| C5 | No hardware-free on-ramp | No "train on a Hub dataset in 15 minutes" page; `notebooks.mdx` is 30 lines of links. |
|
|
||||||
|
|
||||||
**High (breaks common tasks):**
|
|
||||||
|
|
||||||
| # | Problem | Evidence |
|
|
||||||
|---|---|---|
|
|
||||||
| H1 | Training guidance scattered | Spread across `act.mdx`, `hardware_guide.mdx`, `multi_gpu_training.mdx`, `il_robots.mdx` — no single "how to train" page. |
|
|
||||||
| H2 | No policy comparison/decision page | 12 policy pages; ACT says "recommended first" but nothing maps VRAM/task/data-size → policy. |
|
|
||||||
| H3 | Recording how-to fused with format spec | `il_robots.mdx` (procedure) vs `lerobot-dataset-v3.mdx` (spec + Python API + migration) — neither is a clean reference or a clean guide. |
|
|
||||||
| H4 | No benchmark selection guidance | 10 benchmark pages in a flat list; "Adding a New Benchmark" (contributor content) sits *first*. |
|
|
||||||
| H5 | No troubleshooting page | Fixes are buried in `installation.mdx`, `il_robots.mdx`, `lerobot-dataset-v3.mdx` "Common Issues", and AGENT_GUIDE §5.8. |
|
|
||||||
| H6 | Teleoperators near-invisible | 15 implementations, 1 page. Keyboard/gamepad (`--teleop.type=keyboard|gamepad`) have no docs at all. |
|
|
||||||
|
|
||||||
**Medium / low (debt and polish):**
|
|
||||||
|
|
||||||
| # | Problem | Evidence |
|
|
||||||
|---|---|---|
|
|
||||||
| M1 | 14 orphan `policy_*_README.md` in `docs/source/` | Verbatim copies of `src/lerobot/policies/*/README.md`; not in toctree, not built; no sync script exists (verified by repo-wide grep) — hand-copied and already drifting. |
|
|
||||||
| M2 | 6/17 CLI commands undocumented | See §2.1. `cheat-sheet.mdx` covers only the main workflow commands. |
|
|
||||||
| M3 | Mixed-audience sections | "Tutorials" holds beginner + RL-researcher + contributor content; "Benchmarks" holds user + contributor content. |
|
|
||||||
| M4 | Naming inconsistencies | "SO-101"/"SO101"/"so101" in prose; `<Tip>` vs `> [!NOTE]` admonitions mixed; hyphen vs underscore slugs (`lerobot-dataset-v3` vs `multi_gpu_training`). |
|
|
||||||
| M5 | Section opener pages missing | "Robot Processors" starts at "Introduction to Robot Processors" (good) but Robots/Policies/Benchmarks/Teleoperators start with an arbitrary instance page. |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. What great docs do (and what we borrow)
|
|
||||||
|
|
||||||
**Diátaxis** (diataxis.fr — adopted by Canonical, Cloudflare, Gatsby): documentation serves four distinct needs that should never share a page — *tutorials* (learning by doing), *how-to guides* (task recipes), *reference* (lookup facts), *explanation* (concepts). LeRobot's biggest pages fail precisely because they fuse three of these. → **We borrow:** the four-way separation as the organizing principle of the new sidebar; splits like `lerobot-dataset-v3` (spec) vs `using_lerobot_dataset` (how-to).
|
|
||||||
|
|
||||||
**FastAPI**: famous for a numbered, linear tutorial where each chapter is small, runnable, and builds on the previous one; advanced topics live in a separate section so the learn-path is never interrupted. → **We borrow:** the numbered `learn/` course with one outcome per page and a "Next:" footer; advanced content (RL, multi-GPU) moved out of the beginner's way.
|
|
||||||
|
|
||||||
**Ultralytics YOLO**: handles a many-models × many-tasks × many-modes matrix with *orthogonal* navigation axes plus comparison pages, instead of forcing one hierarchy. → **We borrow:** Robots, Teleoperators, Policies, and Benchmarks as separate catalog sections, each opening with a comparison table ("Choosing a…" pages).
|
|
||||||
|
|
||||||
**HF Transformers** (same doc-builder tooling — everything it does, we can do): card-grid landing page, `<hfoptions>` tabs with **persistent selection** (pick PyTorch once, every code block follows), nested collapsible toctree sections, subdirectory slugs (`model_doc/bert`), `_redirects.yml`. → **We borrow:** card-based `index`, site-wide `<hfoptions id="robot">` / `id="os"` / `id="train_env"` tabs, a collapsible nested course section, and the redirects file.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Proposed information architecture
|
|
||||||
|
|
||||||
**17 sections, ~101 toctree pages** (today: 15 sections, 76 pages — growth is almost entirely the 11 deliberately-small course pages; the 14 orphans are absorbed into 2 pages). The sidebar reads top-to-bottom as increasing expertise. Slugs in parentheses; annotations: **[NEW]** to be written · **[KEPT]** same file/slug · **[MOVED]** same file, new section · **[SPLIT from X]** · **[MERGED from X+Y]**.
|
|
||||||
|
|
||||||
```
|
|
||||||
Get started
|
|
||||||
├─ LeRobot (index) ........................................ [KEPT — rewritten as card-grid landing]
|
|
||||||
├─ Installation (installation) ............................ [KEPT — full install matrix, reference]
|
|
||||||
├─ Learn: your first robot — SO-101 course (nested, collapsible)
|
|
||||||
│ ├─ Welcome: what you'll build (learn/index) ............ [NEW]
|
|
||||||
│ ├─ 1. Install LeRobot (learn/install) .................. [SPLIT from installation — happy path only]
|
|
||||||
│ ├─ 2. Assemble your SO-101 (learn/assemble) ............ [SPLIT from so101]
|
|
||||||
│ ├─ 3. Set up the motors (learn/motors) ................. [SPLIT from so101]
|
|
||||||
│ ├─ 4. Calibrate (learn/calibrate) ...................... [SPLIT from so101 + il_robots]
|
|
||||||
│ ├─ 5. Teleoperate your arms (learn/teleoperate) ........ [SPLIT from il_robots]
|
|
||||||
│ ├─ 6. Connect your cameras (learn/cameras) ............. [SPLIT from cameras + il_robots]
|
|
||||||
│ ├─ 7. Record a dataset (learn/record) .................. [SPLIT from il_robots + AGENT_GUIDE §5.7]
|
|
||||||
│ ├─ 8. Train your first policy (learn/train) ............ [SPLIT from il_robots + AGENT_GUIDE §7]
|
|
||||||
│ ├─ 9. Evaluate your policy (learn/evaluate) ............ [SPLIT from il_robots + AGENT_GUIDE §8]
|
|
||||||
│ └─ 10. Next steps (learn/next_steps) ................... [NEW]
|
|
||||||
├─ Quickstart without a robot (quickstart_no_robot) ....... [NEW — AGENT_GUIDE §3 Path B + notebooks]
|
|
||||||
└─ Troubleshooting (troubleshooting) ...................... [NEW — aggregated from 5 pages + AGENT_GUIDE §5.8]
|
|
||||||
|
|
||||||
Collect data
|
|
||||||
├─ Record a dataset (record_dataset) ...................... [SPLIT from il_robots — robot-agnostic, CLI+Python]
|
|
||||||
├─ Get high-quality data (data_collection_tips) ........... [NEW — AGENT_GUIDE §5]
|
|
||||||
├─ Visualize & replay episodes (visualize_replay) ......... [SPLIT from il_robots + using_dataset_tools]
|
|
||||||
├─ Edit datasets (using_dataset_tools) .................... [KEPT — retitled]
|
|
||||||
├─ Port large datasets to v3 (porting_datasets_v3) ........ [KEPT]
|
|
||||||
└─ Add language instructions (language_and_recipes) ....... [KEPT — retitled]
|
|
||||||
|
|
||||||
Train policies
|
|
||||||
├─ Train a policy (train_policy) .......................... [SPLIT from il_robots + AGENT_GUIDE §7; local/Colab/HF-Jobs tabs]
|
|
||||||
├─ Compute requirements (hardware_guide) .................. [KEPT — retitled]
|
|
||||||
├─ Multi-GPU training (multi_gpu_training) ................ [MOVED]
|
|
||||||
├─ Fine-tune with PEFT / LoRA (peft_training) ............. [MOVED]
|
|
||||||
├─ PyTorch accelerators (torch_accelerators) .............. [KEPT]
|
|
||||||
└─ Rename map & empty cameras (rename_map) ................ [MOVED]
|
|
||||||
|
|
||||||
Deploy & evaluate
|
|
||||||
├─ Evaluate a policy on your robot (evaluate_policy) ...... [SPLIT from il_robots + AGENT_GUIDE §8]
|
|
||||||
├─ Deploy a policy — lerobot-rollout (inference) .......... [KEPT]
|
|
||||||
├─ Async inference (async) ................................ [KEPT]
|
|
||||||
└─ Real-time chunking — RTC (rtc) ......................... [KEPT — absorbs policy_rtc_README.md]
|
|
||||||
|
|
||||||
Robots
|
|
||||||
├─ Choosing a robot (choose_a_robot) ...................... [NEW — comparison table + decision guidance]
|
|
||||||
├─ SO-101 (so101) ......................................... [SPLIT — assembly/motors/calibration → course; page becomes spec + sourcing + links]
|
|
||||||
├─ SO-100 — previous generation (so100) ................... [KEPT + legacy banner]
|
|
||||||
├─ Koch v1.1 (koch) · LeKiwi (lekiwi) · Hope Jr (hope_jr) · Reachy 2 (reachy2)
|
|
||||||
│ Unitree G1 (unitree_g1) · Earth Rover Mini (earthrover_mini_plus)
|
|
||||||
│ OMX (omx) · OpenArm (openarm) · reBot B601-DM (rebot_b601) ... [all KEPT, normalized to robot template]
|
|
||||||
|
|
||||||
Teleoperators
|
|
||||||
├─ Teleoperator overview (teleoperators_overview) ......... [NEW — pairing matrix for all 15; keyboard/gamepad usage]
|
|
||||||
└─ Phone teleoperation (phone_teleop) ..................... [KEPT]
|
|
||||||
|
|
||||||
Cameras & motors
|
|
||||||
├─ Cameras (cameras) ...................................... [KEPT — expanded: ZMQ + Reachy 2 backends]
|
|
||||||
├─ Feetech motors & firmware (feetech) .................... [MOVED]
|
|
||||||
└─ Damiao motors & CAN bus (damiao) ....................... [MOVED — documents lerobot-setup-can]
|
|
||||||
|
|
||||||
Policies
|
|
||||||
├─ Choosing a policy (choose_a_policy) .................... [NEW — AGENT_GUIDE §6: VRAM tiers, decision rules]
|
|
||||||
├─ ACT (act) .............................................. [KEPT — absorbs policy_act_README.md]
|
|
||||||
├─ Diffusion Policy (diffusion) ........................... [NEW — from policy_diffusion_README.md; no page exists today]
|
|
||||||
├─ SmolVLA (smolvla) · π₀ (pi0) · π₀-FAST (pi0fast) · π₀.₅ (pi05)
|
|
||||||
│ GR00T N1.5 (groot) · MolmoAct2 (molmoact2) · VLA-JEPA (vla_jepa) · EO-1 (eo1)
|
|
||||||
│ X-VLA (xvla) · Multitask DiT (multi_task_dit) · WALL-OSS (walloss) ... [all KEPT, absorb matching READMEs, normalized to policy template]
|
|
||||||
└─ Legacy policies — VQ-BeT, TDMPC (legacy_policies) ...... [NEW — MERGED from policy_vqbet_README + policy_tdmpc_README]
|
|
||||||
|
|
||||||
Reward models
|
|
||||||
├─ Reward models overview (reward_models_overview) ........ [NEW]
|
|
||||||
└─ SARM (sarm) · ROBOMETER (robometer) · TOPReward (topreward) ... [KEPT]
|
|
||||||
|
|
||||||
Datasets in depth
|
|
||||||
├─ LeRobotDataset format — v3 spec (lerobot-dataset-v3) ... [SPLIT — keeps slug; format/layout/migration only]
|
|
||||||
├─ Load & stream datasets in Python (using_lerobot_dataset) [SPLIT from lerobot-dataset-v3 — Python API, streaming, transforms]
|
|
||||||
├─ Tool-calling columns (tools) ........................... [KEPT — retitled]
|
|
||||||
├─ Video encoding parameters (video_encoding_parameters) .. [KEPT]
|
|
||||||
└─ Streaming video encoding (streaming_video_encoding) .... [KEPT]
|
|
||||||
|
|
||||||
Simulation
|
|
||||||
├─ Environments from the Hub (envhub) ..................... [KEPT — section opener]
|
|
||||||
├─ LeIsaac: control & train in Isaac Sim (envhub_leisaac) . [KEPT]
|
|
||||||
└─ NVIDIA IsaacLab Arena (envhub_isaaclab_arena) .......... [MOVED from Benchmarks]
|
|
||||||
|
|
||||||
Benchmarks
|
|
||||||
├─ Choosing a benchmark (choose_a_benchmark) .............. [NEW — comparison + match-benchmark-to-claim guidance]
|
|
||||||
└─ LIBERO (libero) · LIBERO-plus (libero_plus) · Meta-World (metaworld)
|
|
||||||
RoboTwin 2.0 (robotwin) · RoboCasa365 (robocasa) · RoboCerebra (robocerebra)
|
|
||||||
RoboMME (robomme) · VLABench (vlabench) ................ [all KEPT]
|
|
||||||
|
|
||||||
Processors
|
|
||||||
├─ What are processors? (introduction_processors) ......... [KEPT — entry-level rewrite + diagram]
|
|
||||||
├─ Processors for robots & teleops (processors_robots_teleop) [KEPT]
|
|
||||||
├─ Environment processors (env_processor) ................. [KEPT]
|
|
||||||
├─ Action representations (action_representations) ........ [KEPT]
|
|
||||||
└─ Debug a processor pipeline (debug_processor_pipeline) .. [KEPT]
|
|
||||||
|
|
||||||
Reinforcement learning
|
|
||||||
├─ Train with RL on a real robot — HIL-SERL (hilserl) ..... [MOVED — split candidate, see §10 Q7]
|
|
||||||
├─ Train RL in simulation (hilserl_sim) ................... [MOVED]
|
|
||||||
└─ Human-in-the-loop data collection (hil_data_collection) [MOVED]
|
|
||||||
|
|
||||||
Extending LeRobot
|
|
||||||
├─ Add a policy (bring_your_own_policies) ................. [MOVED]
|
|
||||||
├─ Add a robot (integrate_hardware) ....................... [MOVED — retitled]
|
|
||||||
├─ Write your own processor (implement_your_own_processor) [MOVED]
|
|
||||||
└─ Add a benchmark (adding_benchmarks) .................... [MOVED]
|
|
||||||
|
|
||||||
Reference & resources
|
|
||||||
├─ CLI reference (cli_reference) .......................... [MERGED from cheat-sheet + NEW content for 6 undocumented commands]
|
|
||||||
├─ LeLab: browser GUI (lelab) ............................. [MOVED — placement is open question Q5]
|
|
||||||
├─ Notebooks (notebooks) .................................. [MOVED]
|
|
||||||
└─ Backward compatibility (backwardcomp) .................. [MOVED]
|
|
||||||
|
|
||||||
Community
|
|
||||||
└─ Contribute to LeRobot (contributing) ................... [KEPT]
|
|
||||||
```
|
|
||||||
|
|
||||||
**`_redirects.yml`** (complete file — only two entries):
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
cheat-sheet: cli_reference
|
|
||||||
il_robots: learn/index
|
|
||||||
```
|
|
||||||
|
|
||||||
Everything else keeps its slug; toctree moves and retitles don't change URLs in doc-builder.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. The beginner course, page by page
|
|
||||||
|
|
||||||
**Template for every course page:** goal sentence → "You need" list → numbered steps, each ending with a verifiable checkpoint ("you should see…") → mini-troubleshooting (≤3 most common failures) → "Next:" footer link. **CLI-only — Python appears nowhere in the course.** OS-specific commands use `<hfoptions id="os">`. Budget ≤200 lines per page (assembly exempt: media-heavy).
|
|
||||||
|
|
||||||
| # | Slug | Title | Goal — the user can… | Content source | Length | Next |
|
|
||||||
|---|---|---|---|---|---|---|
|
|
||||||
| 0 | `learn/index` | Welcome: what you'll build | See the outcome (a trained ACT policy moving a real SO-101), the shopping list (~$200 kit, 2 cameras, GPU or Colab), time budget (~1 weekend) | NEW; framing from AGENT_GUIDE §2–3 | ~120 | 1 |
|
|
||||||
| 1 | `learn/install` | Install LeRobot | Working env via one happy path (`pip install 'lerobot[feetech]'`), verified with `lerobot-info` | `installation.mdx` steps 1–3, trimmed to one path | ~100 | 2 |
|
|
||||||
| 2 | `learn/assemble` | Assemble your SO-101 | Source/print parts, assemble leader + follower | `so101.mdx` sourcing + assembly sections (videos carry the load) | ~400 | 3 |
|
|
||||||
| 3 | `learn/motors` | Set up the motors | Find ports (`lerobot-find-port`), set IDs (`lerobot-setup-motors`) for both arms | `so101.mdx` motor config + `cheat-sheet.mdx` | ~150 | 4 |
|
|
||||||
| 4 | `learn/calibrate` | Calibrate | Run `lerobot-calibrate` on both arms; know where calibration files live | `so101.mdx` + `il_robots.mdx` calibration sections | ~120 | 5 |
|
|
||||||
| 5 | `learn/teleoperate` | Teleoperate your arms | Run `lerobot-teleoperate` — the first "wow" moment, deliberately **before** cameras | `il_robots.mdx` teleop (CLI only) | ~80 | 6 |
|
|
||||||
| 6 | `learn/cameras` | Connect your cameras | Detect with `lerobot-find-cameras`, position wrist + front cams, teleop with camera view | `cameras.mdx` + `il_robots.mdx` + AGENT_GUIDE §5.1 placement tips | ~120 | 7 |
|
|
||||||
| 7 | `learn/record` | Record a dataset | Record 50 episodes of one task with `lerobot-record`; keyboard controls, resume, Hub push | `il_robots.mdx` record + AGENT_GUIDE §5.7 defaults ("50 episodes, one task, fixed camera") + §5.2 | ~180 | 8 |
|
|
||||||
| 8 | `learn/train` | Train your first policy (ACT) | Launch `lerobot-train` with sane ACT defaults; know how long to wait and when to stop. Tabs: local GPU / Colab / HF Jobs (`<hfoptions id="train_env">`) | `il_robots.mdx` train + AGENT_GUIDE §7.1, §7.3, §7.7 | ~150 | 9 |
|
|
||||||
| 9 | `learn/evaluate` | Evaluate your policy | Run the policy, measure success over 10 trials, know what's "good" | `il_robots.mdx` eval + AGENT_GUIDE §8.1, §8.3, §5.8 | ~120 | 10 |
|
|
||||||
| 10 | `learn/next_steps` | Next steps | Pick a direction: better data → `data_collection_tips`; language/multi-task → `smolvla`; mobile → `lekiwi`; RL → `hilserl`; sim → `envhub`; Discord | NEW; card grid | ~60 | — |
|
|
||||||
|
|
||||||
**Parallel track — `quickstart_no_robot`:** one page mirroring steps 7–9 without hardware: pick a Hub dataset → `lerobot-dataset-viz` → train ACT (Colab button + local command) → evaluate in sim. Ends with "got a robot? start the course." Source: AGENT_GUIDE §3 Path B + `notebooks.mdx`. (Blessed dataset/env to be decided — §10 Q8.)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. New pages — content briefs
|
|
||||||
|
|
||||||
**Hub / decision pages:**
|
|
||||||
|
|
||||||
- **`choose_a_robot`** — Comparison table of all 11 robots: photo, type (arm / mobile / humanoid / hand), DoF, motor family (Feetech / Dynamixel / Damiao), approx. price, sourcing (kit / 3D-print / buy), teleop options, docs maturity. Below: short decision paragraphs ("first robot → SO-101", "mobile manipulation → LeKiwi", "research platform → Reachy 2 / G1"). Ends with a card into the course.
|
|
||||||
- **`teleoperators_overview`** — Pairing matrix: all 15 teleoperators × compatible robots, with the exact `--teleop.type=` value for each. Inline usage sections for **keyboard** and **gamepad** (currently undocumented); leader arms link to their robot pages; phone links to `phone_teleop`.
|
|
||||||
- **`choose_a_policy`** — Table from AGENT_GUIDE §6.1: policy, params, min VRAM tier, single/multi-task, language-conditioned, inference speed, pretrained checkpoints, dataset-size sweet spot. Decision rules from §6.2 ("first policy → ACT; language + 16 GB → SmolVLA; ≥40 GB → π₀ / GR00T"). Legacy policies in a collapsed footnote.
|
|
||||||
- **`reward_models_overview`** — Half page: what reward models are for (success classification for RL, episode filtering, eval scoring), where they plug into HIL-SERL, 3-row table (SARM / ROBOMETER / TOPReward: input modality, output, use case).
|
|
||||||
- **`choose_a_benchmark`** — Table: benchmark, sim engine, # tasks, focus (long-horizon / language / generalization / bimanual), policies with reported results, GPU needs, Dockerfile availability (AGENT_GUIDE §8.2b). Guidance: match the benchmark to your policy's claim.
|
|
||||||
|
|
||||||
**Getting started:**
|
|
||||||
|
|
||||||
- **`index` (rewrite)** — Card grid replacing prose: "Build your first robot (course)" / "Quickstart without a robot" / "Choose a robot" / "Choose a policy" / "Browse datasets & models on the Hub" / "Join Discord". Keep the one-paragraph mission statement.
|
|
||||||
- **`troubleshooting`** — Symptom-organized, one H2 per failure area: installation (ffmpeg, CUDA) · ports & permissions (find-port fails, udev) · motors (wrong ID, no torque, LED codes from AGENT_GUIDE) · calibration drift · cameras (fps, USB bandwidth) · recording (crashes, resume) · training (flat loss, OOM → compute guide) · policy acts erratically (AGENT_GUIDE §5.8 signal table). Each entry: symptom → cause → fix command.
|
|
||||||
|
|
||||||
**How-to consolidations (Diátaxis splits):**
|
|
||||||
|
|
||||||
- **`record_dataset`** — Robot-agnostic deep version of course step 7: full `lerobot-record` flag reference, `<hfoptions id="robot">` per-robot commands, Python API in its own H2, Hub upload.
|
|
||||||
- **`data_collection_tips`** — AGENT_GUIDE §5 nearly verbatim: ergonomics, practice runs, consistency, the "start small" golden rule, troubleshooting signals.
|
|
||||||
- **`visualize_replay`** — `lerobot-dataset-viz`, the online Hub visualizer, `lerobot-replay`.
|
|
||||||
- **`train_policy`** — One canonical training how-to: `lerobot-train` anatomy, steps/batch/LR guidance (AGENT_GUIDE §7), resume, checkpoints, W&B, local/Colab/HF-Jobs tabs. Links out to multi-GPU, PEFT, compute guide.
|
|
||||||
- **`evaluate_policy`** — Real-robot eval protocol (n trials, success criteria, §8.3 targets), `lerobot-eval` for sim, comparing checkpoints.
|
|
||||||
- **`using_lerobot_dataset`** — Python API split out of `lerobot-dataset-v3`: load from Hub, random access, `delta_timestamps`, DataLoader, streaming, image transforms (+ `lerobot-imgtransform-viz`).
|
|
||||||
|
|
||||||
**Reference:**
|
|
||||||
|
|
||||||
- **`cli_reference`** — Replaces cheat-sheet. All 17 commands, one H2 each, grouped by workflow: *Setup* (find-port, setup-motors, setup-can, calibrate, find-cameras, info) · *Data* (record, replay, dataset-viz, edit-dataset, imgtransform-viz, find-joint-limits) · *Train* (train, train-tokenizer) · *Deploy* (eval, rollout, teleoperate). Each: one-line purpose, copy-paste SO-101 example, link to the relevant guide.
|
|
||||||
- **`diffusion`** — Standard policy-template page built from `policy_diffusion_README.md` (the only major policy with no page today).
|
|
||||||
- **`legacy_policies`** — VQ-BeT + TDMPC on one page with a status banner ("maintained for reproducibility, not recommended for new projects"), minimal train/eval commands, paper links.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Migration table
|
|
||||||
|
|
||||||
**Redirects needed: 2.** Orphan files aren't in the toctree (never built/served), so deleting them breaks no URLs. Toctree moves and retitles with unchanged slugs need no redirects.
|
|
||||||
|
|
||||||
### 7.1 Current toctree pages (76)
|
|
||||||
|
|
||||||
| Old slug | Disposition | New location | Redirect |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `index` | rewrite (card grid) | Get started | no |
|
|
||||||
| `installation` | keep (reference matrix); happy path copied to `learn/install` | Get started | no |
|
|
||||||
| `cheat-sheet` | **merge** into `cli_reference` | Reference & resources | **yes** |
|
|
||||||
| `il_robots` | **split 8 ways** → `learn/{calibrate,teleoperate,cameras,record,train,evaluate}`, `record_dataset`, `visualize_replay`, `train_policy`, `evaluate_policy` | — | **yes → `learn/index`** |
|
|
||||||
| `lelab` | move | Reference & resources | no |
|
|
||||||
| `bring_your_own_policies` | move, retitle "Add a policy" | Extending | no |
|
|
||||||
| `integrate_hardware` | move, retitle "Add a robot" | Extending | no |
|
|
||||||
| `hilserl`, `hilserl_sim`, `hil_data_collection` | move | Reinforcement learning | no |
|
|
||||||
| `multi_gpu_training`, `peft_training`, `torch_accelerators`, `rename_map` | move | Train policies | no |
|
|
||||||
| `hardware_guide` | move, retitle "Compute requirements" | Train policies | no |
|
|
||||||
| `lerobot-dataset-v3` | **split**: keeps slug as format spec; Python usage → `using_lerobot_dataset` | Datasets in depth | no |
|
|
||||||
| `porting_datasets_v3`, `language_and_recipes` | move | Collect data | no |
|
|
||||||
| `using_dataset_tools` | move, retitle "Edit datasets"; viz section → `visualize_replay` | Collect data | no |
|
|
||||||
| `tools` | move, retitle "Tool-calling columns" | Datasets in depth | no |
|
|
||||||
| `video_encoding_parameters`, `streaming_video_encoding` | move | Datasets in depth | no |
|
|
||||||
| `act`, `smolvla`, `pi0`, `pi0fast`, `pi05`, `molmoact2`, `vla_jepa`, `eo1`, `groot`, `xvla`, `multi_task_dit`, `walloss` | keep; normalize to policy template; absorb matching orphan READMEs | Policies (after `choose_a_policy`) | no |
|
|
||||||
| `sarm`, `robometer`, `topreward` | keep | Reward models (after overview) | no |
|
|
||||||
| `inference` | keep, retitle "Deploy a policy (lerobot-rollout)" | Deploy & evaluate | no |
|
|
||||||
| `async` | keep | Deploy & evaluate | no |
|
|
||||||
| `rtc` | keep; absorb `policy_rtc_README.md` | Deploy & evaluate | no |
|
|
||||||
| `envhub`, `envhub_leisaac` | keep | Simulation | no |
|
|
||||||
| `envhub_isaaclab_arena` | move | Simulation | no |
|
|
||||||
| `adding_benchmarks` | move, retitle "Add a benchmark" | Extending | no |
|
|
||||||
| `libero`, `libero_plus`, `metaworld`, `robotwin`, `robocasa`, `robocerebra`, `robomme`, `vlabench` | keep | Benchmarks (after `choose_a_benchmark`) | no |
|
|
||||||
| `introduction_processors` | keep; entry-level rewrite | Processors | no |
|
|
||||||
| `processors_robots_teleop`, `env_processor`, `action_representations`, `debug_processor_pipeline` | keep | Processors | no |
|
|
||||||
| `implement_your_own_processor` | move | Extending | no |
|
|
||||||
| `so101` | **split**: assembly → `learn/assemble`, motors → `learn/motors`, calibration → `learn/calibrate`; page becomes spec/sourcing/overview linking into the course | Robots | no (slug unchanged) |
|
|
||||||
| `so100` | keep + legacy banner ("SO-101 is the current generation") | Robots | no |
|
|
||||||
| `koch`, `lekiwi`, `hope_jr`, `reachy2`, `unitree_g1`, `earthrover_mini_plus`, `omx`, `openarm`, `rebot_b601` | keep; normalize to robot template | Robots (after `choose_a_robot`) | no |
|
|
||||||
| `phone_teleop` | keep | Teleoperators (after overview) | no |
|
|
||||||
| `cameras` | keep; expand (ZMQ + Reachy 2 backends) | Cameras & motors | no |
|
|
||||||
| `feetech`, `damiao` | move; document `lerobot-setup-can` in damiao | Cameras & motors | no |
|
|
||||||
| `notebooks`, `backwardcomp` | move | Reference & resources | no |
|
|
||||||
| `contributing` | keep | Community | no |
|
|
||||||
|
|
||||||
### 7.2 Orphan files (14 — delete from `docs/source/`, no redirects needed)
|
|
||||||
|
|
||||||
Canonical copies remain in `src/lerobot/policies/*/README.md`. Verified: the docs/source copies are byte-identical hand-copies with **no sync script anywhere in the repo** — deletion is safe once unique content is folded in.
|
|
||||||
|
|
||||||
| Orphan file | Disposition |
|
|
||||||
|---|---|
|
|
||||||
| `policy_act_README.md`, `policy_groot_README.md`, `policy_molmoact2_README.md`, `policy_multi_task_dit_README.md`, `policy_pi0_README.md`, `policy_pi05_README.md`, `policy_sarm_README.md`, `policy_smolvla_README.md`, `policy_vla_jepa_README.md`, `policy_walloss_README.md`, `policy_rtc_README.md` | Delete; fold any unique content into the corresponding `.mdx` page |
|
|
||||||
| `policy_diffusion_README.md` | Delete; content seeds the **new** `diffusion.mdx` |
|
|
||||||
| `policy_tdmpc_README.md`, `policy_vqbet_README.md` | Delete; content seeds the **new** `legacy_policies.mdx` |
|
|
||||||
|
|
||||||
### 7.3 New files (27)
|
|
||||||
|
|
||||||
`learn/index`, `learn/install`, `learn/assemble`, `learn/motors`, `learn/calibrate`, `learn/teleoperate`, `learn/cameras`, `learn/record`, `learn/train`, `learn/evaluate`, `learn/next_steps`, `quickstart_no_robot`, `troubleshooting`, `record_dataset`, `data_collection_tips`, `visualize_replay`, `train_policy`, `evaluate_policy`, `using_lerobot_dataset`, `choose_a_robot`, `teleoperators_overview`, `choose_a_policy`, `diffusion`, `legacy_policies`, `reward_models_overview`, `choose_a_benchmark`, `cli_reference` — plus `_redirects.yml`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Conventions & style guide
|
|
||||||
|
|
||||||
**Naming**
|
|
||||||
- Product names in prose: **SO-101**, SO-100, LeKiwi, π₀ (display) / `pi0` (slug/code), HIL-SERL, LeRobotDataset. Never "so101"/"SO101" in prose.
|
|
||||||
- Slugs: lowercase with underscores for new files (`choose_a_robot`); course under `learn/`. Existing hyphenated slugs (`lerobot-dataset-v3`, `cheat-sheet`) are grandfathered, not propagated.
|
|
||||||
- Titles: sentence case. How-to titles start with a verb ("Record a dataset"); reference titles are nouns ("CLI reference"); course titles are numbered ("3. Set up the motors").
|
|
||||||
- Placeholders in commands: consistent `<angle_brackets>` (e.g. `--robot.port=<your_port>`).
|
|
||||||
|
|
||||||
**Components**
|
|
||||||
- `<hfoptions>` with **fixed, site-wide ids** so a choice persists across pages: `id="robot"` · `id="os"` · `id="install"` · `id="train_env"` (local/colab/jobs). Tabs only for structurally parallel content; never hide unique content inside a tab.
|
|
||||||
- `<Tip>` for actionable advice; `<Tip warning={true}>` for anything that can damage hardware, lose data, or cost money. Standardize on `<Tip>` in `.mdx`; `> [!NOTE]` only in plain `.md`. Max ~3 tips per page.
|
|
||||||
- Card grids only on `index`, `learn/next_steps`, and section overview pages.
|
|
||||||
- Every assembly/calibration video must have the same steps in text below it (accessibility + searchability).
|
|
||||||
|
|
||||||
**Page templates & length budgets**
|
|
||||||
|
|
||||||
| Type | Skeleton | Budget |
|
|
||||||
|---|---|---|
|
|
||||||
| Course page | goal → "you need" → numbered steps w/ checkpoints → mini-troubleshooting → Next | ≤200 lines (assembly exempt) |
|
|
||||||
| How-to | problem statement → minimal working command → variations → edge cases → related links | ≤300 |
|
|
||||||
| Robot page | hero photo → spec table → buy/build → calibration quirks → compatible teleops → known issues | ≤400 |
|
|
||||||
| Policy page | summary card (params, VRAM, license, paper, checkpoints) → when to use → install → train → finetune → results → citation | ≤250 |
|
|
||||||
| Concept page | what & why in 3 sentences → diagram → details → pointers to how-tos | ≤400 |
|
|
||||||
|
|
||||||
Hard cap **500 lines** for any page → must split (`il_robots` at ~650 and `hilserl` at ~950 are the cautionary examples).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Phased rollout plan
|
|
||||||
|
|
||||||
| Phase | Scope | PRs | Effort | Parallel? |
|
|
||||||
|---|---|---|---|---|
|
|
||||||
| **0. Sign-off** | This proposal reviewed; open questions (§10) decided | — | ~0.5 wk elapsed | — |
|
|
||||||
| **1. Skeleton & cleanup** | New `_toctree.yml` (all KEPT/MOVED pages in final positions), `_redirects.yml`, delete 14 orphans, retitles, SO-100 legacy banner | 1 mechanical PR (no content changes — easy review) | 1–2 days | no |
|
|
||||||
| **2. Beginner course** | 11 `learn/` pages, `so101` + `il_robots` splits, `quickstart_no_robot`, `troubleshooting`; `il_robots` redirect ships here | 3 PRs (steps 0–4 / 5–10 / quickstart+troubleshooting) | 5–8 days | partially — single author for voice; troubleshooting separable |
|
|
||||||
| **3. Hub / decision pages** | `choose_a_robot`, `choose_a_policy`, `choose_a_benchmark`, `teleoperators_overview`, `reward_models_overview`, `index` rewrite | 6 independent PRs | 4–6 days | **fully** — good first issues per area owner |
|
|
||||||
| **4. How-to consolidation** | `record_dataset`, `data_collection_tips`, `visualize_replay`, `train_policy`, `evaluate_policy`, dataset split, cheat-sheet → `cli_reference` (+ its redirect) | 3–4 PRs (data / train / deploy / datasets) | 5–7 days | parallel across clusters |
|
|
||||||
| **5. Reference fill-in** | `diffusion`, `legacy_policies`, ZMQ camera docs, 6 undocumented CLI commands, keyboard/gamepad teleop content, policy/robot pages normalized to templates | many small PRs | 4–6 days | fully — community-friendly |
|
|
||||||
|
|
||||||
Total ≈ **4–5 person-weeks**. Phases 2 and 3 can run concurrently after Phase 1 merges. Each URL-breaking change ships in the same PR as its redirect.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Open questions for maintainers
|
|
||||||
|
|
||||||
1. **Legacy policies (TDMPC, VQ-BeT):** combined `legacy_policies` page (proposed), full individual pages, or src READMEs only?
|
|
||||||
2. **Orphan READMEs:** verified byte-identical hand-copies of `src/lerobot/policies/*/README.md` with no sync script — OK to delete from `docs/source/` in Phase 1? Which copy is canonical going forward (proposed: src)?
|
|
||||||
3. **AGENT_GUIDE.md future:** large parts become docs pages (§5–§8). Slim it to a pointer file for AI agents (avoids divergence), or keep self-contained and accept dual maintenance?
|
|
||||||
4. **Course URL namespace:** `learn/` (proposed) vs `getting_started/` vs `tutorial/` — permanent once shipped, decide before Phase 2.
|
|
||||||
5. **LeLab status:** first-party and maintained? If yes, add a "prefer a GUI?" callout inside the course (teleop/record steps); if experimental, keep in Reference & resources as proposed.
|
|
||||||
6. **SO-100 messaging:** proposed banner "SO-101 is the current generation" + full docs retained. Strong enough, or formally deprecate?
|
|
||||||
7. **HIL-SERL split:** `hilserl.mdx` is ~950 lines. Split into "Setup & demonstrations" + "Reward classifier & training" in a later phase, or keep monolithic for its expert audience?
|
|
||||||
8. **Blessed no-hardware quickstart:** PushT, a Hub SO-101 dataset, gym-hil, or an EnvHub env? Determines `quickstart_no_robot` content.
|
|
||||||
9. **`<hfoptions>` id taxonomy:** agree on the canonical ids (`robot`, `os`, `install`, `train_env`) — they persist site-wide once shipped.
|
|
||||||
10. **Reward models placement:** standalone section after Policies (proposed) vs nested under Reinforcement learning — depends on whether SARM/ROBOMETER/TOPReward are positioned as general eval tools or RL components.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Appendix: sources & verification
|
|
||||||
|
|
||||||
- Current nav: `docs/source/_toctree.yml` (76 pages, 15 sections — verified June 2026).
|
|
||||||
- File census: 90 content files in `docs/source/`; orphans confirmed by toctree diff; orphan↔src byte-identity confirmed by `diff`; absence of a sync script confirmed by repo-wide grep.
|
|
||||||
- CLI commands: `pyproject.toml [project.scripts]` (17 entries).
|
|
||||||
- AGENT_GUIDE.md section references (§5 data tips, §6 policy choice, §7 training duration, §8 evaluation) verified against its headings.
|
|
||||||
- Patterns referenced: Diátaxis (diataxis.fr) · FastAPI docs · Ultralytics YOLO docs · HF Transformers docs (doc-builder feature ceiling).
|
|
||||||
@@ -58,7 +58,7 @@ test-act-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
|||||||
--env.episode_length=5 \
|
--env.episode_length=5 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1
|
--eval.batch_size=1
|
||||||
|
|
||||||
|
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||||
|
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||||
|
# backend, so it does not require a real model checkpoint or GPU.
|
||||||
|
annotation-e2e:
|
||||||
|
uv run python -m tests.annotations.run_e2e_smoke
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ action = model.select_action(obs)
|
|||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
|
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
|
||||||
|
|
||||||
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
|
||||||
|
|
||||||
@@ -101,11 +101,13 @@ lerobot-train \
|
|||||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||||
```
|
```
|
||||||
|
|
||||||
| Category | Models |
|
| Category | Models |
|
||||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
|
||||||
|
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
|
||||||
|
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
|
||||||
|
|
||||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||||
|
|
||||||
@@ -133,6 +135,8 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
|||||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||||
|
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||||
|
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
@@ -140,7 +144,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
|
|||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{cadene2024lerobot,
|
@misc{cadene2024lerobot,
|
||||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
|
||||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||||
year = {2024}
|
year = {2024}
|
||||||
|
|||||||
@@ -1,172 +0,0 @@
|
|||||||
- sections:
|
|
||||||
- local: index
|
|
||||||
title: LeRobot
|
|
||||||
- local: installation
|
|
||||||
title: Installation
|
|
||||||
- local: cheat-sheet
|
|
||||||
title: Cheat sheet
|
|
||||||
title: Get started
|
|
||||||
- sections:
|
|
||||||
- local: il_robots
|
|
||||||
title: Imitation Learning for Robots
|
|
||||||
- local: bring_your_own_policies
|
|
||||||
title: Adding a Policy
|
|
||||||
- local: integrate_hardware
|
|
||||||
title: Bring Your Own Hardware
|
|
||||||
- local: hilserl
|
|
||||||
title: Train a Robot with RL
|
|
||||||
- local: hilserl_sim
|
|
||||||
title: Train RL in Simulation
|
|
||||||
- local: multi_gpu_training
|
|
||||||
title: Multi GPU training
|
|
||||||
- local: hil_data_collection
|
|
||||||
title: Human In the Loop Data Collection
|
|
||||||
- local: peft_training
|
|
||||||
title: Training with PEFT (e.g., LoRA)
|
|
||||||
- local: rename_map
|
|
||||||
title: Using Rename Map and Empty Cameras
|
|
||||||
title: "Tutorials"
|
|
||||||
- sections:
|
|
||||||
- local: hardware_guide
|
|
||||||
title: Compute Hardware Guide
|
|
||||||
- local: torch_accelerators
|
|
||||||
title: PyTorch accelerators
|
|
||||||
title: "Compute & Hardware"
|
|
||||||
- sections:
|
|
||||||
- local: lerobot-dataset-v3
|
|
||||||
title: Using LeRobotDataset
|
|
||||||
- local: porting_datasets_v3
|
|
||||||
title: Porting Large Datasets
|
|
||||||
- local: using_dataset_tools
|
|
||||||
title: Using the Dataset Tools
|
|
||||||
- local: language_and_recipes
|
|
||||||
title: Language Columns and Recipes
|
|
||||||
- local: tools
|
|
||||||
title: Tools
|
|
||||||
- local: video_encoding_parameters
|
|
||||||
title: Video encoding parameters
|
|
||||||
- local: streaming_video_encoding
|
|
||||||
title: Streaming Video Encoding
|
|
||||||
title: "Datasets"
|
|
||||||
- sections:
|
|
||||||
- local: act
|
|
||||||
title: ACT
|
|
||||||
- local: smolvla
|
|
||||||
title: SmolVLA
|
|
||||||
- local: pi0
|
|
||||||
title: π₀ (Pi0)
|
|
||||||
- local: pi0fast
|
|
||||||
title: π₀-FAST (Pi0Fast)
|
|
||||||
- local: pi05
|
|
||||||
title: π₀.₅ (Pi05)
|
|
||||||
- local: eo1
|
|
||||||
title: EO-1
|
|
||||||
- local: groot
|
|
||||||
title: NVIDIA GR00T N1.5
|
|
||||||
- local: xvla
|
|
||||||
title: X-VLA
|
|
||||||
- local: multi_task_dit
|
|
||||||
title: Multitask DiT Policy
|
|
||||||
- local: walloss
|
|
||||||
title: WALL-OSS
|
|
||||||
title: "Policies"
|
|
||||||
- sections:
|
|
||||||
- local: sarm
|
|
||||||
title: SARM
|
|
||||||
title: "Reward Models"
|
|
||||||
- sections:
|
|
||||||
- local: inference
|
|
||||||
title: Policy Deployment (lerobot-rollout)
|
|
||||||
- local: async
|
|
||||||
title: Use Async Inference
|
|
||||||
- local: rtc
|
|
||||||
title: Real-Time Chunking (RTC)
|
|
||||||
title: "Inference"
|
|
||||||
- sections:
|
|
||||||
- local: envhub
|
|
||||||
title: Environments from the Hub
|
|
||||||
- local: envhub_leisaac
|
|
||||||
title: Control & Train Robots in Sim (LeIsaac)
|
|
||||||
title: "Simulation"
|
|
||||||
- sections:
|
|
||||||
- local: adding_benchmarks
|
|
||||||
title: Adding a New Benchmark
|
|
||||||
- local: libero
|
|
||||||
title: LIBERO
|
|
||||||
- local: libero_plus
|
|
||||||
title: LIBERO-plus
|
|
||||||
- local: metaworld
|
|
||||||
title: Meta-World
|
|
||||||
- local: robotwin
|
|
||||||
title: RoboTwin 2.0
|
|
||||||
- local: robocasa
|
|
||||||
title: RoboCasa365
|
|
||||||
- local: robocerebra
|
|
||||||
title: RoboCerebra
|
|
||||||
- local: robomme
|
|
||||||
title: RoboMME
|
|
||||||
- local: envhub_isaaclab_arena
|
|
||||||
title: NVIDIA IsaacLab Arena Environments
|
|
||||||
- local: vlabench
|
|
||||||
title: VLABench
|
|
||||||
title: "Benchmarks"
|
|
||||||
- sections:
|
|
||||||
- local: introduction_processors
|
|
||||||
title: Introduction to Robot Processors
|
|
||||||
- local: debug_processor_pipeline
|
|
||||||
title: Debug your processor pipeline
|
|
||||||
- local: implement_your_own_processor
|
|
||||||
title: Implement your own processor
|
|
||||||
- local: processors_robots_teleop
|
|
||||||
title: Processors for Robots and Teleoperators
|
|
||||||
- local: env_processor
|
|
||||||
title: Environment Processors
|
|
||||||
- local: action_representations
|
|
||||||
title: Action Representations
|
|
||||||
title: "Robot Processors"
|
|
||||||
- sections:
|
|
||||||
- local: so101
|
|
||||||
title: SO-101
|
|
||||||
- local: so100
|
|
||||||
title: SO-100
|
|
||||||
- local: koch
|
|
||||||
title: Koch v1.1
|
|
||||||
- local: lekiwi
|
|
||||||
title: LeKiwi
|
|
||||||
- local: hope_jr
|
|
||||||
title: Hope Jr
|
|
||||||
- local: reachy2
|
|
||||||
title: Reachy 2
|
|
||||||
- local: unitree_g1
|
|
||||||
title: Unitree G1
|
|
||||||
- local: earthrover_mini_plus
|
|
||||||
title: Earth Rover Mini
|
|
||||||
- local: omx
|
|
||||||
title: OMX
|
|
||||||
- local: openarm
|
|
||||||
title: OpenArm
|
|
||||||
- local: rebot_b601
|
|
||||||
title: reBot B601-DM
|
|
||||||
title: "Robots"
|
|
||||||
- sections:
|
|
||||||
- local: phone_teleop
|
|
||||||
title: Phone
|
|
||||||
title: "Teleoperators"
|
|
||||||
- sections:
|
|
||||||
- local: cameras
|
|
||||||
title: Cameras
|
|
||||||
title: "Sensors"
|
|
||||||
- sections:
|
|
||||||
- local: notebooks
|
|
||||||
title: Notebooks
|
|
||||||
- local: feetech
|
|
||||||
title: Updating Feetech Firmware
|
|
||||||
- local: damiao
|
|
||||||
title: Damiao Motors and CAN Bus
|
|
||||||
title: "Resources"
|
|
||||||
- sections:
|
|
||||||
- local: contributing
|
|
||||||
title: Contribute to LeRobot
|
|
||||||
- local: backwardcomp
|
|
||||||
title: Backward compatibility
|
|
||||||
title: "About"
|
|
||||||
+141
-171
@@ -1,214 +1,184 @@
|
|||||||
# LeRobot documentation table of contents
|
|
||||||
#
|
|
||||||
# Ordering principle: gentle onboarding first, advanced/custom work last.
|
|
||||||
# Within each top-level section the same rule applies — concept/overview pages
|
|
||||||
# before reference/per-item pages.
|
|
||||||
#
|
|
||||||
# Pages marked "NEW (to create)" do not yet exist as .mdx files; they are
|
|
||||||
# placeholders for the redesign and must be authored before the docs build.
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: index
|
- local: index
|
||||||
title: 🤗 LeRobot
|
title: LeRobot
|
||||||
- local: quickstart # NEW (to create) — 15-min zero-to-trained-ACT path
|
|
||||||
title: Quickstart
|
|
||||||
- local: installation
|
- local: installation
|
||||||
title: Installation
|
title: Installation
|
||||||
- local: core_concepts # NEW (to create) — datasets, policies, processors, robots, envs in one mental model
|
|
||||||
title: Core concepts
|
|
||||||
- local: cheat-sheet
|
- local: cheat-sheet
|
||||||
title: Command cheat sheet
|
title: Cheat sheet
|
||||||
title: Get started
|
title: Get started
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: il_robots
|
- local: il_robots
|
||||||
title: Imitation learning end-to-end
|
title: Imitation Learning for Robots
|
||||||
|
- local: lelab
|
||||||
|
title: LeLab - Lerobot GUI
|
||||||
|
- local: bring_your_own_policies
|
||||||
|
title: Adding a Policy
|
||||||
|
- local: integrate_hardware
|
||||||
|
title: Bring Your Own Hardware
|
||||||
|
- local: hilserl
|
||||||
|
title: Train a Robot with RL
|
||||||
|
- local: hilserl_sim
|
||||||
|
title: Train RL in Simulation
|
||||||
|
- local: multi_gpu_training
|
||||||
|
title: Multi GPU training
|
||||||
- local: hil_data_collection
|
- local: hil_data_collection
|
||||||
title: Human-in-the-loop data collection
|
title: Human In the Loop Data Collection
|
||||||
- local: inference
|
- local: peft_training
|
||||||
title: Deploying a trained policy
|
title: Training with PEFT (e.g., LoRA)
|
||||||
- local: rename_map
|
- local: rename_map
|
||||||
title: Matching dataset keys to a policy (rename map)
|
title: Using Rename Map and Empty Cameras
|
||||||
title: Your first project
|
title: "Tutorials"
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: hardware_guide
|
- local: hardware_guide
|
||||||
title: Compute hardware guide
|
title: Compute Hardware Guide
|
||||||
- local: torch_accelerators
|
- local: torch_accelerators
|
||||||
title: PyTorch accelerators
|
title: PyTorch accelerators
|
||||||
- local: multi_gpu_training
|
title: "Compute & Hardware"
|
||||||
title: Multi-GPU training
|
|
||||||
- local: peft_training
|
|
||||||
title: Parameter-efficient fine-tuning (LoRA)
|
|
||||||
title: Training
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
title: Using LeRobotDataset
|
title: Using LeRobotDataset
|
||||||
|
- local: porting_datasets_v3
|
||||||
|
title: Porting Large Datasets
|
||||||
- local: using_dataset_tools
|
- local: using_dataset_tools
|
||||||
title: Dataset tools
|
title: Using the Dataset Tools
|
||||||
- local: language_and_recipes
|
- local: language_and_recipes
|
||||||
title: Language columns & recipes
|
title: Language Columns and Recipes
|
||||||
- local: tools
|
- local: tools
|
||||||
title: Tool calls in datasets
|
title: Tools
|
||||||
|
- local: annotation_pipeline
|
||||||
|
title: Annotation Pipeline
|
||||||
- local: video_encoding_parameters
|
- local: video_encoding_parameters
|
||||||
title: Video encoding parameters
|
title: Video encoding parameters
|
||||||
- local: streaming_video_encoding
|
- local: streaming_video_encoding
|
||||||
title: Streaming video encoding
|
title: Streaming Video Encoding
|
||||||
- local: porting_datasets_v3
|
title: "Datasets"
|
||||||
title: Porting datasets to v3
|
|
||||||
title: Datasets
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: policies_overview # NEW (to create) — concept hub + "choose a policy" decision guide
|
- local: act
|
||||||
title: Choosing a policy
|
title: ACT
|
||||||
- sections:
|
- local: smolvla
|
||||||
- local: act
|
title: SmolVLA
|
||||||
title: ACT
|
- local: pi0
|
||||||
- local: smolvla
|
title: π₀ (Pi0)
|
||||||
title: SmolVLA
|
- local: pi0fast
|
||||||
- local: pi0
|
title: π₀-FAST (Pi0Fast)
|
||||||
title: π₀ (Pi0)
|
- local: pi05
|
||||||
- local: pi0fast
|
title: π₀.₅ (Pi05)
|
||||||
title: π₀-FAST
|
- local: molmoact2
|
||||||
- local: pi05
|
title: MolmoAct2
|
||||||
title: π₀.₅ (Pi05)
|
- local: vla_jepa
|
||||||
- local: eo1
|
title: VLA-JEPA
|
||||||
title: EO-1
|
- local: eo1
|
||||||
- local: groot
|
title: EO-1
|
||||||
title: NVIDIA GR00T N1.5
|
- local: groot
|
||||||
- local: xvla
|
title: NVIDIA GR00T N1.5
|
||||||
title: X-VLA
|
- local: xvla
|
||||||
- local: walloss
|
title: X-VLA
|
||||||
title: WALL-OSS
|
- local: multi_task_dit
|
||||||
- local: multi_task_dit
|
title: Multitask DiT Policy
|
||||||
title: Multitask DiT
|
- local: walloss
|
||||||
title: Policy reference
|
title: WALL-OSS
|
||||||
title: Policies
|
title: "Policies"
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: async
|
|
||||||
title: Async inference
|
|
||||||
- local: rtc
|
|
||||||
title: Real-time chunking (RTC)
|
|
||||||
title: Real-time deployment
|
|
||||||
|
|
||||||
- sections:
|
|
||||||
- local: hilserl
|
|
||||||
title: Train a robot with RL (HIL-SERL)
|
|
||||||
- local: hilserl_sim
|
|
||||||
title: Train RL in simulation
|
|
||||||
- local: sarm
|
- local: sarm
|
||||||
title: SARM reward model
|
title: SARM
|
||||||
title: Reinforcement learning
|
- local: robometer
|
||||||
|
title: ROBOMETER
|
||||||
|
- local: topreward
|
||||||
|
title: TOPReward
|
||||||
|
title: "Reward Models"
|
||||||
|
- sections:
|
||||||
|
- local: inference
|
||||||
|
title: Policy Deployment (lerobot-rollout)
|
||||||
|
- local: async
|
||||||
|
title: Use Async Inference
|
||||||
|
- local: rtc
|
||||||
|
title: Real-Time Chunking (RTC)
|
||||||
|
title: "Inference"
|
||||||
- sections:
|
- sections:
|
||||||
- local: envhub
|
- local: envhub
|
||||||
title: Environments from the Hub
|
title: Environments from the Hub
|
||||||
- local: envhub_leisaac
|
- local: envhub_leisaac
|
||||||
title: LeIsaac — control & train in sim
|
title: Control & Train Robots in Sim (LeIsaac)
|
||||||
|
title: "Simulation"
|
||||||
|
- sections:
|
||||||
|
- local: adding_benchmarks
|
||||||
|
title: Adding a New Benchmark
|
||||||
|
- local: libero
|
||||||
|
title: LIBERO
|
||||||
|
- local: libero_plus
|
||||||
|
title: LIBERO-plus
|
||||||
|
- local: metaworld
|
||||||
|
title: Meta-World
|
||||||
|
- local: robotwin
|
||||||
|
title: RoboTwin 2.0
|
||||||
|
- local: robocasa
|
||||||
|
title: RoboCasa365
|
||||||
|
- local: robocerebra
|
||||||
|
title: RoboCerebra
|
||||||
|
- local: robomme
|
||||||
|
title: RoboMME
|
||||||
- local: envhub_isaaclab_arena
|
- local: envhub_isaaclab_arena
|
||||||
title: NVIDIA IsaacLab Arena environments
|
title: NVIDIA IsaacLab Arena Environments
|
||||||
- sections:
|
- local: vlabench
|
||||||
- local: libero
|
title: VLABench
|
||||||
title: LIBERO
|
title: "Benchmarks"
|
||||||
- local: libero_plus
|
|
||||||
title: LIBERO-plus
|
|
||||||
- local: metaworld
|
|
||||||
title: Meta-World
|
|
||||||
- local: robotwin
|
|
||||||
title: RoboTwin 2.0
|
|
||||||
- local: robocasa
|
|
||||||
title: RoboCasa365
|
|
||||||
- local: robocerebra
|
|
||||||
title: RoboCerebra
|
|
||||||
- local: robomme
|
|
||||||
title: RoboMME
|
|
||||||
- local: vlabench
|
|
||||||
title: VLABench
|
|
||||||
title: Benchmark suites
|
|
||||||
title: Simulation & benchmarks
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: introduction_processors
|
- local: introduction_processors
|
||||||
title: Introduction to processors
|
title: Introduction to Robot Processors
|
||||||
- local: processors_robots_teleop
|
|
||||||
title: Processors for robots & teleoperators
|
|
||||||
- local: env_processor
|
|
||||||
title: Environment processors
|
|
||||||
- local: action_representations
|
|
||||||
title: Action representations
|
|
||||||
- local: debug_processor_pipeline
|
- local: debug_processor_pipeline
|
||||||
title: Debugging a pipeline
|
title: Debug your processor pipeline
|
||||||
- local: implement_your_own_processor
|
- local: implement_your_own_processor
|
||||||
title: Implementing your own processor
|
title: Implement your own processor
|
||||||
title: Processors
|
- local: processors_robots_teleop
|
||||||
|
title: Processors for Robots and Teleoperators
|
||||||
|
- local: env_processor
|
||||||
|
title: Environment Processors
|
||||||
|
- local: action_representations
|
||||||
|
title: Action Representations
|
||||||
|
title: "Robot Processors"
|
||||||
- sections:
|
- sections:
|
||||||
- sections:
|
- local: so101
|
||||||
- local: so101
|
title: SO-101
|
||||||
title: SO-101
|
- local: so100
|
||||||
- local: so100
|
title: SO-100
|
||||||
title: SO-100
|
- local: koch
|
||||||
- local: koch
|
title: Koch v1.1
|
||||||
title: Koch v1.1
|
- local: lekiwi
|
||||||
- local: omx
|
title: LeKiwi
|
||||||
title: OMX
|
- local: hope_jr
|
||||||
- local: openarm
|
title: Hope Jr
|
||||||
title: OpenArm
|
- local: reachy2
|
||||||
title: Low-cost arms
|
title: Reachy 2
|
||||||
- sections:
|
- local: unitree_g1
|
||||||
- local: lekiwi
|
title: Unitree G1
|
||||||
title: LeKiwi
|
- local: earthrover_mini_plus
|
||||||
- local: earthrover_mini_plus
|
title: Earth Rover Mini
|
||||||
title: Earth Rover Mini
|
- local: omx
|
||||||
title: Mobile platforms
|
title: OMX
|
||||||
- sections:
|
- local: openarm
|
||||||
- local: hope_jr
|
title: OpenArm
|
||||||
title: Hope Jr
|
- local: rebot_b601
|
||||||
- local: reachy2
|
title: reBot B601-DM
|
||||||
title: Reachy 2
|
title: "Robots"
|
||||||
- local: unitree_g1
|
- sections:
|
||||||
title: Unitree G1
|
- local: phone_teleop
|
||||||
title: Bimanual & humanoid
|
title: Phone
|
||||||
- sections:
|
title: "Teleoperators"
|
||||||
- local: rebot_b601
|
|
||||||
title: reBot B601-DM
|
|
||||||
title: Research & industrial
|
|
||||||
title: Supported robots
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: cameras
|
- local: cameras
|
||||||
title: Cameras
|
title: Cameras
|
||||||
- local: phone_teleop
|
title: "Sensors"
|
||||||
title: Phone teleoperation
|
|
||||||
- local: feetech
|
|
||||||
title: Feetech firmware update
|
|
||||||
- local: damiao
|
|
||||||
title: Damiao motors & CAN bus
|
|
||||||
title: Sensors, teleop & motors
|
|
||||||
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: integrate_hardware
|
|
||||||
title: Bring your own hardware
|
|
||||||
- local: bring_your_own_policies
|
|
||||||
title: Add a new policy
|
|
||||||
- local: adding_benchmarks
|
|
||||||
title: Add a new benchmark
|
|
||||||
title: Extend LeRobot
|
|
||||||
|
|
||||||
- sections:
|
|
||||||
- local: troubleshooting # NEW (to create) — common errors: USB, calibration drift, CUDA OOM, video decoding…
|
|
||||||
title: Troubleshooting & FAQ
|
|
||||||
- local: glossary # NEW (to create) — episode, action chunk, leader/follower, teleop, processor…
|
|
||||||
title: Glossary
|
|
||||||
- local: notebooks
|
- local: notebooks
|
||||||
title: Example notebooks
|
title: Notebooks
|
||||||
- local: backwardcomp
|
- local: feetech
|
||||||
title: Backward compatibility
|
title: Updating Feetech Firmware
|
||||||
title: Reference
|
- local: damiao
|
||||||
|
title: Damiao Motors and CAN Bus
|
||||||
|
title: "Resources"
|
||||||
- sections:
|
- sections:
|
||||||
- local: contributing
|
- local: contributing
|
||||||
title: Contributing to LeRobot
|
title: Contribute to LeRobot
|
||||||
title: About
|
- local: backwardcomp
|
||||||
|
title: Backward compatibility
|
||||||
|
title: "About"
|
||||||
|
|||||||
@@ -0,0 +1,291 @@
|
|||||||
|
# Annotation Pipeline
|
||||||
|
|
||||||
|
`lerobot-annotate` watches each episode's video with a vision-language
|
||||||
|
model (VLM) and writes natural-language annotations back into your
|
||||||
|
dataset. It fills the two language columns from the
|
||||||
|
[Language Columns and Recipes](./language_and_recipes) page —
|
||||||
|
`language_persistent` and `language_events` — straight into
|
||||||
|
`data/chunk-*/file-*.parquet`.
|
||||||
|
|
||||||
|
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||||
|
memory, interjections, speech, and visual Q&A that a policy can be
|
||||||
|
trained on.
|
||||||
|
|
||||||
|
## How it fits together
|
||||||
|
|
||||||
|
```text
|
||||||
|
your dataset lerobot-annotate
|
||||||
|
(LeRobot v3.1)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────┐
|
||||||
|
│ read episodes │
|
||||||
|
└──────────────────────────┬──────────────────────────┘
|
||||||
|
│
|
||||||
|
┌────────────────────┼────────────────────┐
|
||||||
|
▼ ▼ ▼
|
||||||
|
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||||
|
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||||
|
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||||
|
└────────────────────┼─────────────────────┘
|
||||||
|
│ each module stages raw JSONL
|
||||||
|
▼ into .annotate_staging/
|
||||||
|
┌─────────────────┐
|
||||||
|
│ validator │ ◀── checks everything
|
||||||
|
└────────┬────────┘
|
||||||
|
▼
|
||||||
|
┌─────────────────┐
|
||||||
|
│ writer │
|
||||||
|
└────────┬────────┘
|
||||||
|
▼
|
||||||
|
data/chunk-*/file-*.parquet
|
||||||
|
(+ meta/info.json tools)
|
||||||
|
```
|
||||||
|
|
||||||
|
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||||
|
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||||
|
single writer rewrites the dataset shards in place.
|
||||||
|
|
||||||
|
## What the pipeline produces
|
||||||
|
|
||||||
|
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||||
|
the two language columns:
|
||||||
|
|
||||||
|
| Style / atom | Column | Module |
|
||||||
|
| ------------------------------------------- | --------------------- | --------------- |
|
||||||
|
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||||
|
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||||
|
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||||
|
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
|
||||||
|
| `interjection` | `language_events` | `interjections` |
|
||||||
|
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
|
||||||
|
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||||
|
|
||||||
|
### How subtasks are generated
|
||||||
|
|
||||||
|
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||||
|
it uses a two-step **describe → segment** flow:
|
||||||
|
|
||||||
|
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||||
|
chosen camera (no guessing about the task).
|
||||||
|
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||||
|
episode into consecutive atomic subtasks.
|
||||||
|
|
||||||
|
Both passes see the episode as **timestamped contact sheets** — frames
|
||||||
|
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
|
||||||
|
grids with each frame's time burned into its corner, so the VLM cites
|
||||||
|
exact boundary times directly. This is far cheaper in vision tokens than
|
||||||
|
one image per frame, so the sampling can stay dense; episodes longer than
|
||||||
|
`max_frames_per_prompt` are split into windows at the same density and
|
||||||
|
merged. Both prompts also carry a causal **event-boundary** definition (a
|
||||||
|
new event starts when an object becomes held / is released / reaches a new
|
||||||
|
location / a lid changes state / contents move) to sharpen where cuts land.
|
||||||
|
|
||||||
|
The resulting spans are then stitched into a gap-free, full-episode
|
||||||
|
cover, so **every frame has exactly one active subtask**. See
|
||||||
|
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||||
|
for the production settings (single camera, timestamped contact sheets,
|
||||||
|
auto-windowed subtask generation).
|
||||||
|
|
||||||
|
### Tools
|
||||||
|
|
||||||
|
The writer does **not** add a `tools` column to the parquet. The tool
|
||||||
|
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||||
|
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||||
|
that list, keeping any tools you declared beforehand.
|
||||||
|
|
||||||
|
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||||
|
pipeline preserves whatever is already there. That makes the tool visible
|
||||||
|
to the chat template, so the model can learn to _generate_ the call. The
|
||||||
|
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||||
|
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||||
|
this PR — the [Tools](./tools) doc marks those pieces as
|
||||||
|
not-yet-implemented.
|
||||||
|
|
||||||
|
## Running on Hugging Face Jobs
|
||||||
|
|
||||||
|
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||||
|
The repo ships a launcher script you copy and tweak for your dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||||
|
```
|
||||||
|
|
||||||
|
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||||
|
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||||
|
that:
|
||||||
|
|
||||||
|
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||||
|
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||||
|
drives it over the OpenAI-compatible API,
|
||||||
|
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||||
|
with `lerobot-annotate`,
|
||||||
|
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||||
|
back to `--repo_id` in place if you leave that unset).
|
||||||
|
|
||||||
|
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||||
|
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||||
|
(run `lerobot-annotate --help` for the full list).
|
||||||
|
|
||||||
|
## Key options
|
||||||
|
|
||||||
|
These are the flags you'll reach for most often. Run
|
||||||
|
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||||
|
short manipulation episodes.
|
||||||
|
|
||||||
|
### Dataset in / out
|
||||||
|
|
||||||
|
| Flag | Default | What it does |
|
||||||
|
| ----------------- | ------- | ----------------------------------------------------------------------- |
|
||||||
|
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
|
||||||
|
| `--root` | — | Annotate a local dataset directory instead. |
|
||||||
|
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
|
||||||
|
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
|
||||||
|
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
|
||||||
|
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||||
|
|
||||||
|
### Which modules run
|
||||||
|
|
||||||
|
Every module is on by default and can be toggled independently (set to
|
||||||
|
`false` to skip it, e.g. to iterate on one module at a time):
|
||||||
|
|
||||||
|
| Flag | Default | Turns off |
|
||||||
|
| ------------------------- | ------- | ----------------------------------- |
|
||||||
|
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||||
|
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||||
|
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||||
|
|
||||||
|
### The VLM (`--vlm.*`)
|
||||||
|
|
||||||
|
| Flag | Default | What it does |
|
||||||
|
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
|
||||||
|
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
|
||||||
|
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
|
||||||
|
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
|
||||||
|
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
|
||||||
|
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
|
||||||
|
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
|
||||||
|
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
|
||||||
|
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||||
|
|
||||||
|
### Subtasks / plan / memory (`--plan.*`)
|
||||||
|
|
||||||
|
| Flag | Default | What it does |
|
||||||
|
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
|
||||||
|
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
|
||||||
|
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
|
||||||
|
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
|
||||||
|
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
|
||||||
|
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
|
||||||
|
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
|
||||||
|
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
|
||||||
|
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||||
|
|
||||||
|
### Interjections + VQA
|
||||||
|
|
||||||
|
| Flag | Default | What it does |
|
||||||
|
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
|
||||||
|
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
|
||||||
|
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
|
||||||
|
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
|
||||||
|
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||||
|
|
||||||
|
## Contributing new modules
|
||||||
|
|
||||||
|
The pipeline is built to grow, and **contributions are very welcome** —
|
||||||
|
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||||
|
template, a smarter grounding flow, or quality fixes to the existing
|
||||||
|
`plan` / `interjections` / `vqa` modules.
|
||||||
|
|
||||||
|
Every module lives under
|
||||||
|
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||||
|
client and the keyframe cache, writes its raw output to the staging
|
||||||
|
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||||
|
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||||
|
|
||||||
|
## How recipes consume the output
|
||||||
|
|
||||||
|
The annotations are meant to be read by recipes (see
|
||||||
|
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||||
|
|
||||||
|
- low-level / high-level / memory-update branches read
|
||||||
|
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||||
|
- an interjection-response branch reads `interjection` events plus the
|
||||||
|
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||||
|
and the matching `plan` refresh at the same timestamp.
|
||||||
|
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||||
|
`language_events`.
|
||||||
|
|
||||||
|
## Why state and events are split
|
||||||
|
|
||||||
|
Two ideas shape the design:
|
||||||
|
|
||||||
|
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||||
|
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||||
|
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||||
|
the one frame whose timestamp matches. Timestamps are copied straight
|
||||||
|
from the source parquet — never recomputed in floating point.
|
||||||
|
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||||
|
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||||
|
for one model load per dataset, not three.
|
||||||
|
|
||||||
|
## Re-running a single module
|
||||||
|
|
||||||
|
Each module stages its raw output to
|
||||||
|
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||||
|
prompt iteration cheap: re-running one module overwrites only its own
|
||||||
|
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||||
|
don't want with `--plan.enabled=false` (and likewise
|
||||||
|
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
|
||||||
|
|
||||||
|
## What the validator checks
|
||||||
|
|
||||||
|
Before the writer runs, `StagingValidator` confirms:
|
||||||
|
|
||||||
|
- every event row lands exactly on a real frame timestamp;
|
||||||
|
- no speech / interjection pairs are left orphaned;
|
||||||
|
- `plan` is refreshed at every interjection timestamp;
|
||||||
|
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||||
|
- each VQA assistant `content` is valid JSON in one of the
|
||||||
|
bbox / keypoint / count / attribute / spatial shapes;
|
||||||
|
- every row goes to the column chosen by `column_for_style(style)`.
|
||||||
|
|
||||||
|
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||||
|
while debugging.
|
||||||
|
|
||||||
|
## Where each module's ideas come from
|
||||||
|
|
||||||
|
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||||
|
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||||
|
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||||
|
for "how, not what" detail.
|
||||||
|
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||||
|
keep only the minimal relevant information — preserve outcomes, drop
|
||||||
|
specific attributes.
|
||||||
|
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||||
|
situated correction, specific constraint, preference. Speech is a
|
||||||
|
tool-call-only atom
|
||||||
|
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||||
|
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||||
|
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||||
|
keypoints) and Steerable VLA Policies
|
||||||
|
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||||
|
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||||
|
|
||||||
|
When improving a module, tweak its prompt template in
|
||||||
|
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||||
|
rewriting from scratch.
|
||||||
|
|
||||||
|
## Roughly how much it costs
|
||||||
|
|
||||||
|
Per episode, the pipeline makes about `max_steps` plan calls,
|
||||||
|
`max_interjections_per_episode` interjection calls, and
|
||||||
|
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
|
||||||
|
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
|
||||||
|
~50 VLM calls.
|
||||||
|
|
||||||
|
Storage stays small: `language_persistent` is at most tens of KB per
|
||||||
|
episode (parquet dictionary-encodes the one entry that repeats across
|
||||||
|
frames), and `language_events` is empty on most frames — its size scales
|
||||||
|
with the number of emissions, not `num_frames × num_emissions`.
|
||||||
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
|||||||
|
|
||||||
**Compatible teleoperators:**
|
**Compatible teleoperators:**
|
||||||
|
|
||||||
- `openarm_mini` - OpenArm Mini
|
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
||||||
- `so_leader` - SO100 / SO101 leader arm
|
- `so_leader` - SO100 / SO101 leader arm
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
||||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
|
|||||||
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
|||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
"steps": 5000,
|
"steps": 5000,
|
||||||
"log_freq": 10,
|
"log_freq": 10,
|
||||||
"eval_freq": 1000,
|
"env_eval_freq": 1000,
|
||||||
"save_freq": 1000,
|
"save_freq": 1000,
|
||||||
"save_checkpoint": true,
|
"save_checkpoint": true,
|
||||||
"seed": 2,
|
"seed": 2,
|
||||||
|
|||||||
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
|||||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||||
|
|
||||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
|||||||
--strategy.num_episodes=20 \
|
--strategy.num_episodes=20 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||||
--dataset.single_task="Fold the T-shirt"
|
--dataset.single_task="Fold the T-shirt"
|
||||||
```
|
```
|
||||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
|||||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||||
| `--teleop.type` | **Required.** Teleoperator type |
|
| `--teleop.type` | **Required.** Teleoperator type |
|
||||||
|
|
||||||
|
### Episodic (`--strategy.type=episodic`)
|
||||||
|
|
||||||
|
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=episodic \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--teleop.type=so100_leader \
|
||||||
|
--teleop.port=/dev/ttyACM1 \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||||
|
--dataset.num_episodes=20 \
|
||||||
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10 \
|
||||||
|
--dataset.single_task="Pick up the red cube"
|
||||||
|
```
|
||||||
|
|
||||||
|
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||||
|
|
||||||
|
**Keyboard controls:**
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ----------- | -------------------------------- |
|
||||||
|
| `→` (right) | End the current episode early |
|
||||||
|
| `←` (left) | Discard episode and re-record it |
|
||||||
|
| `ESC` | Stop the recording session |
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `--dataset.num_episodes` | Number of episodes to record |
|
||||||
|
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||||
|
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||||
|
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||||
|
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||||
|
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Inference Backends
|
## Inference Backends
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
# LeLab - LeRobot Guide
|
||||||
|
|
||||||
|
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> For now LeLab is compatible only with SO-ARM101
|
||||||
|
|
||||||
|
<Youtube id="VqyKUuW9V1g" />
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||||
|
|
||||||
|
```
|
||||||
|
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||||
|
```
|
||||||
|
|
||||||
|
After install, run `lelab` from your terminal anytime to start the app.
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||||
|
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||||
|
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||||
|
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||||
|
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||||
|
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||||
|
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||||
|
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||||
@@ -275,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
|||||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||||
|
|
||||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||||
```
|
```
|
||||||
|
|
||||||
**What it does**
|
**What it does**
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Reproducing published results
|
## Reproducing published results
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Relationship to LIBERO
|
## Relationship to LIBERO
|
||||||
|
|||||||
@@ -120,11 +120,11 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Practical tips
|
## Practical tips
|
||||||
|
|
||||||
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
||||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
|
||||||
|
|||||||
@@ -0,0 +1,433 @@
|
|||||||
|
# MolmoAct2 Policy
|
||||||
|
|
||||||
|
MolmoAct2 is the LeRobot policy implementation of
|
||||||
|
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
|
||||||
|
training, evaluation, checkpointing, and dataset interfaces for easier use with
|
||||||
|
LeRobot datasets.
|
||||||
|
|
||||||
|
This implementation currently supports training and evaluation for the regular
|
||||||
|
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||||
|
not included in this LeRobot policy yet and is coming soon.
|
||||||
|
|
||||||
|
For the original MolmoAct2 training code used for the experiments reported in
|
||||||
|
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[molmoact2]"
|
||||||
|
```
|
||||||
|
|
||||||
|
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||||
|
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
|
||||||
|
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
|
||||||
|
`gradient_checkpointing=true` and include the forward pass, backward pass,
|
||||||
|
gradient clipping, optimizer step, and optimizer state allocation. Values are
|
||||||
|
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
|
||||||
|
dataloader workers, CUDA context, and fragmentation.
|
||||||
|
|
||||||
|
Multi-GPU training through `accelerate` increases throughput and global batch
|
||||||
|
size, but this LeRobot port does not currently expose the original MolmoAct2
|
||||||
|
`fsdp_devices` model-parallel training path. The current training script has
|
||||||
|
not been tested for multi-node training.
|
||||||
|
|
||||||
|
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
|
||||||
|
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
|
||||||
|
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
|
||||||
|
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
|
||||||
|
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
|
||||||
|
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
|
||||||
|
|
||||||
|
The repo has been tested with Ubuntu 22.04.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use MolmoAct2 in a LeRobot training config, set:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.type=molmoact2
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
|
||||||
|
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
|
||||||
|
the same LeRobot training loop, dataset transforms, checkpoint saving, and
|
||||||
|
logging. The difference is only how the initial policy weights and processor
|
||||||
|
state are loaded.
|
||||||
|
|
||||||
|
### Training With Original MolmoAct2 Weight
|
||||||
|
|
||||||
|
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
|
||||||
|
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
|
||||||
|
the original HF model files, then build its own policy processor from the
|
||||||
|
dataset metadata and the policy options below.
|
||||||
|
|
||||||
|
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
|
||||||
|
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
|
||||||
|
augmentation, and LeRobot's checkpointing/logging path.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch \
|
||||||
|
--num_processes=8 \
|
||||||
|
--mixed_precision=bf16 \
|
||||||
|
-m lerobot.scripts.lerobot_train \
|
||||||
|
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||||
|
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||||
|
--dataset.video_backend=pyav \
|
||||||
|
--dataset.image_transforms.enable=true \
|
||||||
|
--policy.type=molmoact2 \
|
||||||
|
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.action_mode=both \
|
||||||
|
--policy.chunk_size=10 \
|
||||||
|
--policy.n_action_steps=10 \
|
||||||
|
--policy.setup_type="single franka robotic arm in libero" \
|
||||||
|
--policy.control_mode="delta end-effector pose" \
|
||||||
|
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
|
||||||
|
--policy.model_dtype=bfloat16 \
|
||||||
|
--policy.num_flow_timesteps=8 \
|
||||||
|
--policy.gradient_checkpointing=true \
|
||||||
|
--policy.freeze_embedding=true \
|
||||||
|
--policy.normalize_gripper=false \
|
||||||
|
--policy.enable_knowledge_insulation=false \
|
||||||
|
--policy.push_to_hub=false \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--wandb.entity=<wandb_entity> \
|
||||||
|
--wandb.project=<wandb_project> \
|
||||||
|
--job_name=<job_name> \
|
||||||
|
--output_dir=outputs/<job_name> \
|
||||||
|
--steps=10000 \
|
||||||
|
--batch_size=32 \
|
||||||
|
--num_workers=4 \
|
||||||
|
--log_freq=20 \
|
||||||
|
--env_eval_freq=-1 \
|
||||||
|
--save_checkpoint=true \
|
||||||
|
--save_freq=2000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training With LeRobot MolmoAct2 Weight
|
||||||
|
|
||||||
|
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
|
||||||
|
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
|
||||||
|
restores the saved LeRobot policy config, model weights, processor, and
|
||||||
|
normalization statistics. You can still override training-time options such as
|
||||||
|
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch \
|
||||||
|
--num_processes=8 \
|
||||||
|
--mixed_precision=bf16 \
|
||||||
|
-m lerobot.scripts.lerobot_train \
|
||||||
|
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||||
|
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||||
|
--dataset.video_backend=pyav \
|
||||||
|
--dataset.image_transforms.enable=true \
|
||||||
|
--policy.path=/path/to/pretrained_model \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.action_mode=both \
|
||||||
|
--policy.chunk_size=10 \
|
||||||
|
--policy.n_action_steps=10 \
|
||||||
|
--policy.model_dtype=bfloat16 \
|
||||||
|
--policy.num_flow_timesteps=8 \
|
||||||
|
--policy.gradient_checkpointing=true \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--wandb.entity=<wandb_entity> \
|
||||||
|
--wandb.project=<wandb_project> \
|
||||||
|
--job_name=<job_name> \
|
||||||
|
--output_dir=outputs/<job_name> \
|
||||||
|
--steps=10000 \
|
||||||
|
--batch_size=32 \
|
||||||
|
--num_workers=4 \
|
||||||
|
--log_freq=20 \
|
||||||
|
--env_eval_freq=-1 \
|
||||||
|
--save_checkpoint=true \
|
||||||
|
--save_freq=2000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Practices
|
||||||
|
|
||||||
|
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
|
||||||
|
or a real-world dataset with less than 200 demonstrations, a global batch size of
|
||||||
|
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
|
||||||
|
cases, we intentionally keep the action expert fully trainable, which we found
|
||||||
|
to be crucial for model performance. For larger fine-tuning datasets, larger
|
||||||
|
global batch sizes and full fine-tuning are usually preferred.
|
||||||
|
|
||||||
|
### Common Policy Options
|
||||||
|
|
||||||
|
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
|
||||||
|
Use this for released MolmoAct2 weights.
|
||||||
|
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
|
||||||
|
created by LeRobot training.
|
||||||
|
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
|
||||||
|
`both`. `both` trains the flow-matching action expert and the discrete
|
||||||
|
action-token loss.
|
||||||
|
- `policy.train_action_expert_only`: trains only parameters whose names contain
|
||||||
|
`action_expert`. It requires `policy.action_mode=continuous`.
|
||||||
|
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
|
||||||
|
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
|
||||||
|
expert linear layers. When `policy.enable_lora_action_expert=false`, the
|
||||||
|
action expert base weights remain fully trainable while the VLM is trained
|
||||||
|
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
|
||||||
|
action expert is also adapter-tuned instead of fully fine-tuned.
|
||||||
|
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
|
||||||
|
context K/V states before the action loss. The default is `false`.
|
||||||
|
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
|
||||||
|
`10`. This LeRobot port overrides the loaded checkpoint's
|
||||||
|
`max_action_horizon` with this value.
|
||||||
|
- `policy.n_action_steps`: number of actions consumed from each predicted
|
||||||
|
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
|
||||||
|
- `policy.setup_type`: text inserted into the prompt to describe the robot and
|
||||||
|
scene, e.g. `single franka robotic arm in libero`. More examples are listed
|
||||||
|
in the `metadata_by_tag` entries of
|
||||||
|
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
|
||||||
|
- `policy.control_mode`: text inserted into the prompt to describe the action
|
||||||
|
space, e.g. `delta end-effector pose` or `absolute joint pose`.
|
||||||
|
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
|
||||||
|
processor.
|
||||||
|
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
|
||||||
|
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
|
||||||
|
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
|
||||||
|
example during training. We use `8` for fine-tuning.
|
||||||
|
- `policy.num_inference_steps`: optional override for continuous action
|
||||||
|
generation steps at inference time.
|
||||||
|
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
|
||||||
|
to reduce activation memory.
|
||||||
|
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
|
||||||
|
- `policy.normalize_gripper`: controls whether gripper dimensions are included
|
||||||
|
in state/action quantile normalization. The default is `false`.
|
||||||
|
- `policy.normalize_language`: normalizes task strings before prompt
|
||||||
|
construction. The default is `true`.
|
||||||
|
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
|
||||||
|
Released checkpoints use `policy.expected_max_action_dim=32`.
|
||||||
|
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
|
||||||
|
infer it from images, state dimension, action dimension, action horizon, and
|
||||||
|
discrete-action mode.
|
||||||
|
|
||||||
|
### Learning Rates
|
||||||
|
|
||||||
|
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
|
||||||
|
fine-tuning experiments.
|
||||||
|
|
||||||
|
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
|
||||||
|
`policy.optimizer_vit_lr=5e-6` for the vision tower,
|
||||||
|
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
|
||||||
|
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
|
||||||
|
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
|
||||||
|
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
|
||||||
|
`policy.enable_lora_action_expert=false`, so the action expert is still fully
|
||||||
|
fine-tuned with `policy.optimizer_action_expert_lr`. If
|
||||||
|
`policy.enable_lora_action_expert=true`, the action expert is trained through
|
||||||
|
LoRA adapters instead.
|
||||||
|
- Action-expert-only fine-tuning trains only the action expert and uses
|
||||||
|
`policy.optimizer_action_expert_lr=5e-5`.
|
||||||
|
|
||||||
|
You can override the full fine-tuning and action-expert learning rates with
|
||||||
|
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
|
||||||
|
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
|
||||||
|
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
|
||||||
|
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
|
||||||
|
|
||||||
|
### Dataset Quantile Statistics
|
||||||
|
|
||||||
|
MolmoAct2 defaults to quantile normalization for state and action features. If
|
||||||
|
your dataset has not been converted with quantile statistics, you can add them
|
||||||
|
with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||||
|
--repo-id=your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
Alternatively, train MolmoAct2 with mean/std normalization:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
|
||||||
|
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
|
||||||
|
fixed and use `policy.per_episode_seed=true`.
|
||||||
|
|
||||||
|
**Important:** We found that `num_steps_wait=10` does not reliably let the
|
||||||
|
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
|
||||||
|
results reported here use `num_steps_wait=50`.
|
||||||
|
|
||||||
|
### Evaluation With LeRobot MolmoAct2 Weight
|
||||||
|
|
||||||
|
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
|
||||||
|
normalization statistics are restored together with the model.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MUJOCO_GL=egl
|
||||||
|
export PYOPENGL_PLATFORM=egl
|
||||||
|
export OMP_NUM_THREADS=1
|
||||||
|
export MKL_NUM_THREADS=1
|
||||||
|
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
|
||||||
|
--policy.inference_action_mode=continuous \
|
||||||
|
--policy.model_dtype=bfloat16 \
|
||||||
|
--policy.use_amp=true \
|
||||||
|
--policy.enable_inference_cuda_graph=true \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.per_episode_seed=true \
|
||||||
|
--policy.eval_seed=1000 \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
|
||||||
|
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=50 \
|
||||||
|
--seed=1000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation With Original MolmoAct2 Weight
|
||||||
|
|
||||||
|
You can evaluate a released Hugging Face checkpoint directly without first
|
||||||
|
converting it to a LeRobot checkpoint. In this case, set
|
||||||
|
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
|
||||||
|
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
|
||||||
|
normalization statistics, action horizon, prompt metadata, and image-key order
|
||||||
|
from the checkpoint's `norm_stats.json`.
|
||||||
|
|
||||||
|
To fully replicate the MolmoAct2 paper results with released Hugging Face
|
||||||
|
checkpoints, we recommend using the v0.5.1-pinned
|
||||||
|
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
|
||||||
|
branch. That branch matches the original evaluation settings used for the
|
||||||
|
reported numbers.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MUJOCO_GL=egl
|
||||||
|
export PYOPENGL_PLATFORM=egl
|
||||||
|
export OMP_NUM_THREADS=1
|
||||||
|
export MKL_NUM_THREADS=1
|
||||||
|
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.type=molmoact2 \
|
||||||
|
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||||
|
--policy.norm_tag=libero \
|
||||||
|
--policy.inference_action_mode=continuous \
|
||||||
|
--policy.model_dtype=float32 \
|
||||||
|
--policy.use_amp=false \
|
||||||
|
--policy.enable_inference_cuda_graph=true \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.per_episode_seed=true \
|
||||||
|
--policy.eval_seed=1000 \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_goal \
|
||||||
|
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=50 \
|
||||||
|
--seed=1000
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
|
||||||
|
full LIBERO suite. The same command works for other released MolmoAct2
|
||||||
|
checkpoints as long as the requested `policy.norm_tag` exists in that
|
||||||
|
checkpoint's `norm_stats.json`.
|
||||||
|
|
||||||
|
### Common Evaluation Options
|
||||||
|
|
||||||
|
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
|
||||||
|
flow-matching inference or `discrete` for action-token inference. It must be
|
||||||
|
compatible with the training-time `policy.action_mode` saved in the
|
||||||
|
checkpoint.
|
||||||
|
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
|
||||||
|
saved by LeRobot.
|
||||||
|
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
|
||||||
|
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
|
||||||
|
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
|
||||||
|
image-key order, and action horizon from the original checkpoint's
|
||||||
|
`norm_stats.json`. It is required for direct original-HF checkpoint
|
||||||
|
evaluation.
|
||||||
|
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
|
||||||
|
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
|
||||||
|
- `policy.use_amp`: runs the policy forward under autocast during eval. For
|
||||||
|
`model_dtype=bfloat16`, keep this enabled.
|
||||||
|
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
|
||||||
|
graph path for faster repeated continuous-action rollout.
|
||||||
|
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
|
||||||
|
action generation deterministic per episode for replication.
|
||||||
|
- `env.task`: comma-separated LIBERO suites or a single suite. Use
|
||||||
|
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
|
||||||
|
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
|
||||||
|
by the policy processor.
|
||||||
|
|
||||||
|
## Performance Results
|
||||||
|
|
||||||
|
### LIBERO Benchmark Results
|
||||||
|
|
||||||
|
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
|
||||||
|
compare and test its LeRobot implementation, we fine-tuned
|
||||||
|
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
|
||||||
|
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
|
||||||
|
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
|
||||||
|
results.
|
||||||
|
|
||||||
|
The LeRobot fine-tuned checkpoint reported here is available at
|
||||||
|
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
|
||||||
|
and was trained on
|
||||||
|
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
|
||||||
|
|
||||||
|
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
|
||||||
|
| -------------- | ---------------------: | -----------------: |
|
||||||
|
| LIBERO Spatial | 98.4% | 97.8% |
|
||||||
|
| LIBERO Object | 100.0% | 100.0% |
|
||||||
|
| LIBERO Goal | 98.0% | 97.8% |
|
||||||
|
| LIBERO 10 | 96.6% | 93.2% |
|
||||||
|
| Average | 98.25% | 97.20% |
|
||||||
|
|
||||||
|
These results demonstrate MolmoAct2's strong performance across diverse robotic
|
||||||
|
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
|
||||||
|
evaluation section.
|
||||||
|
|
||||||
|
## Differences From the Original Implementation
|
||||||
|
|
||||||
|
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
|
||||||
|
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
|
||||||
|
differences from the original training repository are:
|
||||||
|
|
||||||
|
- The original paper training stack loads the model in fp32 and trains under
|
||||||
|
mixed precision. This LeRobot port usually loads the checkpoint directly in
|
||||||
|
`policy.model_dtype=bfloat16` for lower memory use.
|
||||||
|
- The original repository uses its own FSDP/model-parallel training path. The
|
||||||
|
LeRobot port uses the standard LeRobot/Accelerate training path and has not
|
||||||
|
been tested for multi-node training.
|
||||||
|
- The original repository supports sequence packing. The LeRobot port trains on
|
||||||
|
one LeRobot sample per item and pads to an inferred fixed sequence budget.
|
||||||
|
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
|
||||||
|
dataset transforms, image augmentation, and Weights & Biases logging
|
||||||
|
conventions.
|
||||||
|
- The original training path supports mixed action horizons by padding to
|
||||||
|
`max_action_horizon` and masking padded horizon slots in the action expert
|
||||||
|
self-attention. This is useful when training across datasets with different
|
||||||
|
control frequencies. The LeRobot port currently targets single-dataset
|
||||||
|
fine-tuning, so `policy.chunk_size` overrides the checkpoint
|
||||||
|
`max_action_horizon` and horizon masking is not implemented yet. Support for
|
||||||
|
this mixed-horizon path is planned.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{fang2026molmoact2actionreasoningmodels,
|
||||||
|
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||||
|
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||||
|
year={2026},
|
||||||
|
eprint={2605.02881},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.RO},
|
||||||
|
url={https://arxiv.org/abs/2605.02881},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This model is licensed under Apache 2.0. It is intended for research and
|
||||||
|
educational use in accordance with
|
||||||
|
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||||
|
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||||
@@ -113,6 +113,61 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
|
|||||||
--policy=act
|
--policy=act
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Training Large Models with FSDP
|
||||||
|
|
||||||
|
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
|
||||||
|
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
|
||||||
|
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
|
||||||
|
|
||||||
|
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||||
|
--policy.type=<your_policy> \
|
||||||
|
--output_dir=outputs/train/my_policy_fsdp
|
||||||
|
```
|
||||||
|
|
||||||
|
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
distributed_type: FSDP
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 4
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 1
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
|
||||||
|
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
```
|
||||||
|
|
||||||
|
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
|
||||||
|
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
|
||||||
|
optimizer before `accelerator.prepare()`.
|
||||||
|
|
||||||
|
### FSDP checkpoints
|
||||||
|
|
||||||
|
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
|
||||||
|
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
|
||||||
|
|
||||||
|
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
|
||||||
|
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
|
||||||
|
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
|
||||||
|
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
|
||||||
|
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
|
||||||
|
first, or cast `model.safetensors` to the deployment dtype offline.
|
||||||
|
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
|
||||||
|
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
|
||||||
|
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
|
||||||
|
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
|
||||||
|
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
|
||||||
|
sample-exact when the world size and batch size match the original run (a warning is logged
|
||||||
|
otherwise); the optimizer/model state itself is unaffected.
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ lerobot-train \
|
|||||||
--steps=30000 \
|
--steps=30000 \
|
||||||
--save_freq=1000 \
|
--save_freq=1000 \
|
||||||
--log_freq=100 \
|
--log_freq=100 \
|
||||||
--eval_freq=1000 \
|
--env_eval_freq=1000 \
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--policy.horizon=32 \
|
--policy.horizon=32 \
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ lerobot-train \
|
|||||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||||
--repo-id=your_dataset \
|
--repo-id=your_dataset \
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
# MolmoAct2
|
||||||
|
|
||||||
|
This repository contains the LeRobot policy implementation of
|
||||||
|
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
|
||||||
|
training, evaluation, checkpointing, and dataset compatibility.
|
||||||
|
|
||||||
|
This implementation currently supports training and evaluation for the regular
|
||||||
|
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||||
|
not included in this LeRobot policy yet and is coming soon.
|
||||||
|
|
||||||
|
For the original MolmoAct2 training code used for the experiments reported in
|
||||||
|
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||||
|
|
||||||
|
## LIBERO Evaluation
|
||||||
|
|
||||||
|
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
|
||||||
|
scene stabilize and can degrade measured success. All LIBERO evaluation results
|
||||||
|
reported for this LeRobot implementation use `num_steps_wait=50`.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{fang2026molmoact2actionreasoningmodels,
|
||||||
|
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||||
|
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||||
|
year={2026},
|
||||||
|
eprint={2605.02881},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.RO},
|
||||||
|
url={https://arxiv.org/abs/2605.02881},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This model is licensed under Apache 2.0. It is intended for research and
|
||||||
|
educational use in accordance with
|
||||||
|
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||||
|
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
# VLA-JEPA
|
||||||
|
|
||||||
|
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||||
|
|
||||||
|
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
| Component | Module | Role |
|
||||||
|
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||||
|
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||||
|
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||||
|
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||||
|
|
||||||
|
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||||
|
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||||
|
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||||
|
year = {2026},
|
||||||
|
eprint = {2602.10098},
|
||||||
|
archivePrefix = {arXiv},
|
||||||
|
primaryClass = {cs.RO},
|
||||||
|
url = {https://arxiv.org/abs/2602.10098},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
|||||||
If you have existing datasets in v2.1 format, use the migration tool:
|
If you have existing datasets in v2.1 format, use the migration tool:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||||
--repo-id your_id/existing_dataset
|
--repo-id your_id/existing_dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,219 +0,0 @@
|
|||||||
# Quickstart
|
|
||||||
|
|
||||||
This is the **shortest path** from an unboxed SO-101 to a policy that drives your own robot. Every step is copy-paste; replace the **`<placeholders>`** with the values for your setup.
|
|
||||||
|
|
||||||
By the end you will have:
|
|
||||||
|
|
||||||
- A calibrated SO-101 leader + follower pair.
|
|
||||||
- A dataset of 30 episodes pushed to the Hugging Face Hub.
|
|
||||||
- A trained ACT policy (~20k steps) running on your robot via `lerobot-rollout`.
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> **How long will this take?**
|
|
||||||
> Recording 30 episodes is roughly 30–60 minutes of teleoperation. Training ACT for 20k steps takes ~1.5h on an A100, a few hours on a laptop RTX 3060, longer on Apple Silicon (`mps`). The commands themselves are quick — most of the wall-clock is data collection and training.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> If you only want to **understand the codebase** or **train on an existing dataset without hardware**, this page isn't for you. Read [Core concepts](./core_concepts) first, then jump to [Imitation learning end-to-end](./il_robots).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Before you start
|
|
||||||
|
|
||||||
You need:
|
|
||||||
|
|
||||||
- An **assembled SO-101 leader + follower pair**. If your robot is not assembled yet, follow the [SO-101 assembly guide](./so101) and come back here.
|
|
||||||
- **One or two cameras** (USB webcam works fine).
|
|
||||||
- A **CUDA GPU with ≥ 6 GB VRAM** (ACT is light — a laptop RTX 3060 works). Apple Silicon (`mps`) and CPU are supported but slower. See the [compute hardware guide](./hardware_guide) for sizing.
|
|
||||||
- A **Hugging Face account** — datasets and the trained policy will be pushed to your Hub.
|
|
||||||
|
|
||||||
If any of the above is missing, fix it first; the rest of the page assumes it.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — Install LeRobot
|
|
||||||
|
|
||||||
Follow the full [Installation Guide](./installation) for environment setup, then add the SO-101 motor stack and log in to the Hub:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install 'lerobot[feetech]'
|
|
||||||
git lfs install && git lfs pull
|
|
||||||
hf auth login # paste a token from https://huggingface.co/settings/tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
Sanity check — the CLI entry points should be available:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-find-port --help
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Identify USB ports and motor IDs
|
|
||||||
|
|
||||||
Plug **only the follower arm** in (USB + power) and run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-find-port
|
|
||||||
```
|
|
||||||
|
|
||||||
When prompted, unplug it and press Enter. Note the printed port — that's your `<FOLLOWER_PORT>`. Repeat with only the **leader arm** plugged in to get `<LEADER_PORT>`.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> On Linux, USB ports look like `/dev/ttyACM0`; on macOS like `/dev/tty.usbmodem...`. On Linux you may need `sudo chmod 666 /dev/ttyACM0` to grant access.
|
|
||||||
|
|
||||||
If your motors are brand-new (or repurposed), set their IDs and baudrate **once per arm**:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
|
|
||||||
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
|
|
||||||
```
|
|
||||||
|
|
||||||
The script walks you through connecting motors one at a time. Full details: [SO-101 → Configure the motors](./so101#configure-the-motors).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Calibrate
|
|
||||||
|
|
||||||
Center every joint roughly in the middle of its range, then run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-calibrate \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower
|
|
||||||
|
|
||||||
lerobot-calibrate \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader
|
|
||||||
```
|
|
||||||
|
|
||||||
After pressing Enter, sweep each joint through its full range of motion, then press Enter again to finish.
|
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> The `--robot.id` / `--teleop.id` values (`my_follower`, `my_leader`) become the **calibration keys**. Reuse the same IDs in every later command — that's how LeRobot finds the calibration on disk.
|
|
||||||
|
|
||||||
Watch the [calibration video](./so101#calibrate) if anything is unclear.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 4 — Teleoperate (sanity check, no recording)
|
|
||||||
|
|
||||||
Before recording anything, confirm the leader drives the follower correctly:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-teleoperate \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader \
|
|
||||||
--display_data=true
|
|
||||||
```
|
|
||||||
|
|
||||||
A Rerun window should open showing the camera feed and joint angles. Move the leader — the follower should mirror it in real time. If it doesn't, see [Troubleshooting & FAQ](./troubleshooting).
|
|
||||||
|
|
||||||
Don't know which camera index is which? Run `lerobot-find-cameras` — it saves a frame from each detected camera so you can pick the right one.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 5 — Record a dataset (30 episodes)
|
|
||||||
|
|
||||||
Now record demonstrations. Pick a short, repeatable task (e.g. *"put the red brick in the bowl"*). The dataset is pushed to the Hub under your username:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export HF_USER=<your-hf-username>
|
|
||||||
|
|
||||||
lerobot-record \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--teleop.type=so101_leader \
|
|
||||||
--teleop.port=<LEADER_PORT> \
|
|
||||||
--teleop.id=my_leader \
|
|
||||||
--dataset.repo_id=${HF_USER}/so101_quickstart \
|
|
||||||
--dataset.num_episodes=30 \
|
|
||||||
--dataset.single_task="Put the red brick in the bowl" \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
--display_data=true
|
|
||||||
```
|
|
||||||
|
|
||||||
**Keyboard controls during recording:**
|
|
||||||
|
|
||||||
- **`→` (Right Arrow)** — save the current episode and move to the next.
|
|
||||||
- **`←` (Left Arrow)** — discard the current episode and retry.
|
|
||||||
- **`Esc`** — stop, encode videos, and upload to the Hub.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> **Quality beats quantity.** 30 clean, varied episodes (different brick positions, lighting, camera shake) train a much better policy than 100 identical ones. Move the object around. Vary your speed slightly.
|
|
||||||
|
|
||||||
When you're done, your dataset lives at `https://huggingface.co/datasets/${HF_USER}/so101_quickstart`. You can preview it in the browser. For deeper recording options (resume, multiple tasks, custom processors), see [Imitation learning end-to-end → Record](./il_robots#record-a-dataset).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 6 — Train ACT
|
|
||||||
|
|
||||||
ACT (Action Chunking Transformer) is the right default for a first run — small, fast, and works well on 30 episodes.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-train \
|
|
||||||
--dataset.repo_id=${HF_USER}/so101_quickstart \
|
|
||||||
--policy.type=act \
|
|
||||||
--output_dir=outputs/train/act_so101_quickstart \
|
|
||||||
--job_name=act_so101_quickstart \
|
|
||||||
--policy.device=cuda \
|
|
||||||
--policy.repo_id=${HF_USER}/act_so101_quickstart \
|
|
||||||
--steps=20000 \
|
|
||||||
--wandb.enable=true
|
|
||||||
```
|
|
||||||
|
|
||||||
A few notes:
|
|
||||||
|
|
||||||
- Replace `--policy.device=cuda` with `mps` on Apple Silicon, or `cpu` if you have no GPU (very slow — not recommended for a real run).
|
|
||||||
- `--wandb.enable=true` is optional. If you use it, run `wandb login` first. Otherwise drop the flag.
|
|
||||||
- Checkpoints land in `outputs/train/act_so101_quickstart/checkpoints/`. The final model is also pushed to the Hub at the `--policy.repo_id` you specified.
|
|
||||||
- To resume from an interruption: `lerobot-train --config_path=outputs/train/act_so101_quickstart/checkpoints/last/pretrained_model/train_config.json --resume=true`.
|
|
||||||
|
|
||||||
> [!TIP]
|
|
||||||
> **No GPU locally?** Train on Google Colab using the [ACT notebook](./notebooks#training-act), or rent a GPU via [Hugging Face Jobs](./il_robots#train-using-hugging-face-jobs) — pay-as-you-go, no setup.
|
|
||||||
|
|
||||||
For why ACT is the default and when to switch to SmolVLA, Pi0, or another policy, see [Choosing a policy](./policies_overview).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 7 — Run your policy on the robot
|
|
||||||
|
|
||||||
Deploy with `lerobot-rollout`. **Use the same camera layout you used while recording** — keys and resolutions must match.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-rollout \
|
|
||||||
--strategy.type=base \
|
|
||||||
--policy.path=${HF_USER}/act_so101_quickstart \
|
|
||||||
--robot.type=so101_follower \
|
|
||||||
--robot.port=<FOLLOWER_PORT> \
|
|
||||||
--robot.id=my_follower \
|
|
||||||
--robot.cameras="{ top: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30} }" \
|
|
||||||
--task="Put the red brick in the bowl" \
|
|
||||||
--duration=60
|
|
||||||
```
|
|
||||||
|
|
||||||
`--duration` is in seconds — leave it off to run until you stop the script. You should see the follower arm move on its own, attempting the task.
|
|
||||||
|
|
||||||
If observations from the robot use different keys than the policy expects, you'll need a [rename map](./rename_map). If latency matters, look at [async inference](./async) and [real-time chunking](./rtc).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## You're done 🎉
|
|
||||||
|
|
||||||
You now have a working IL pipeline end-to-end. From here, the natural next steps are:
|
|
||||||
|
|
||||||
- **Improve the policy** — record more diverse episodes, train longer, or try a stronger model. See [Choosing a policy](./policies_overview).
|
|
||||||
- **Go deeper on imitation learning** — [Imitation learning end-to-end](./il_robots) covers multi-camera setups, multi-task datasets, episode replay, evaluation, and Hugging Face Jobs.
|
|
||||||
- **Try RL with a human in the loop** — [HIL-SERL](./hilserl) trains a policy that improves while you correct it.
|
|
||||||
- **Use a different robot** — see [Supported robots](./so101) for low-cost arms, mobile platforms, bimanual, and humanoid.
|
|
||||||
- **Build something new** — [Bring your own hardware](./integrate_hardware) and [Add a new policy](./bring_your_own_policies).
|
|
||||||
|
|
||||||
Stuck on something? Check [Troubleshooting & FAQ](./troubleshooting), or ask on [Discord](https://discord.gg/s3KuuzsPFb).
|
|
||||||
@@ -166,7 +166,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=5 \
|
--eval.n_episodes=5 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -0,0 +1,185 @@
|
|||||||
|
# ROBOMETER
|
||||||
|
|
||||||
|
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
|
||||||
|
|
||||||
|
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
|
||||||
|
**Project**: [robometer.github.io](https://robometer.github.io/)
|
||||||
|
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
|
||||||
|
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
|
||||||
|
|
||||||
|
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
|
||||||
|
- **Success head**: predicts per-frame task success probability.
|
||||||
|
- **Preference head**: predicts which of two trajectories better completes the task during training.
|
||||||
|
|
||||||
|
The paper trains ROBOMETER with a composite objective:
|
||||||
|
|
||||||
|
```text
|
||||||
|
L = L_pref + L_prog + L_succ
|
||||||
|
```
|
||||||
|
|
||||||
|
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
|
||||||
|
|
||||||
|
## What the LeRobot Integration Covers
|
||||||
|
|
||||||
|
- Standard `reward_model.type=robometer` configuration through LeRobot.
|
||||||
|
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
|
||||||
|
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
|
||||||
|
- Dense, frame-level progress and success predictions internally.
|
||||||
|
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
|
||||||
|
|
||||||
|
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||||
|
2. Install the ROBOMETER dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[robometer]"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you use `uv` directly from a source checkout:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra robometer
|
||||||
|
```
|
||||||
|
|
||||||
|
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
|
||||||
|
|
||||||
|
## Model Inputs and Outputs
|
||||||
|
|
||||||
|
ROBOMETER expects:
|
||||||
|
|
||||||
|
- A trajectory video or sequence of frames.
|
||||||
|
- A natural-language task description.
|
||||||
|
|
||||||
|
In LeRobot datasets, the preprocessor reads:
|
||||||
|
|
||||||
|
| Config field | Default | Meaning |
|
||||||
|
| ------------------------- | ------------------------ | ----------------------------------------------------- |
|
||||||
|
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
|
||||||
|
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||||
|
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
|
||||||
|
|
||||||
|
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
|
||||||
|
|
||||||
|
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
|
||||||
|
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Load the Reward Model Directly
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
|
||||||
|
cfg = RobometerConfig(
|
||||||
|
pretrained_path="lerobot/Robometer-4B",
|
||||||
|
device="cuda",
|
||||||
|
reward_output="progress",
|
||||||
|
)
|
||||||
|
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Encode Frames and Compute a Reward
|
||||||
|
|
||||||
|
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||||
|
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||||
|
|
||||||
|
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||||
|
# task: str
|
||||||
|
encoder = RobometerEncoderProcessorStep(
|
||||||
|
base_model_id=cfg.base_model_id,
|
||||||
|
use_multi_image=cfg.use_multi_image,
|
||||||
|
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
|
||||||
|
max_frames=cfg.max_frames,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = encoder.encode_samples([(frames, task)])
|
||||||
|
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
|
||||||
|
|
||||||
|
reward = reward_model.compute_reward(batch)
|
||||||
|
```
|
||||||
|
|
||||||
|
`reward` is a tensor of shape `(batch_size,)`.
|
||||||
|
|
||||||
|
### Use the Reward Factory
|
||||||
|
|
||||||
|
You can also instantiate ROBOMETER through the reward factory:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||||
|
|
||||||
|
cfg = make_reward_model_config(
|
||||||
|
"robometer",
|
||||||
|
pretrained_path="lerobot/Robometer-4B",
|
||||||
|
device="cuda",
|
||||||
|
image_key="observation.images.top",
|
||||||
|
)
|
||||||
|
reward_model = make_reward_model(cfg)
|
||||||
|
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||||
|
```
|
||||||
|
|
||||||
|
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
|
||||||
|
|
||||||
|
## Configuration Notes
|
||||||
|
|
||||||
|
### Backbone and Vocabulary
|
||||||
|
|
||||||
|
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
|
||||||
|
|
||||||
|
```text
|
||||||
|
<|split_token|>
|
||||||
|
<|reward_token|>
|
||||||
|
<|pref_token|>
|
||||||
|
<|sim_token|>
|
||||||
|
<|prog_token|>
|
||||||
|
```
|
||||||
|
|
||||||
|
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
|
||||||
|
|
||||||
|
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
|
||||||
|
|
||||||
|
### Progress Prediction
|
||||||
|
|
||||||
|
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
|
||||||
|
|
||||||
|
### Success Prediction
|
||||||
|
|
||||||
|
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
|
||||||
|
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
|
||||||
|
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [ROBOMETER project](https://robometer.github.io/)
|
||||||
|
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
|
||||||
|
- [Original ROBOMETER code](https://github.com/robometer/robometer)
|
||||||
|
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
|
||||||
|
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{liang2026robometer,
|
||||||
|
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
|
||||||
|
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
|
||||||
|
year={2026},
|
||||||
|
booktitle={Robotics: Science and Systems 2026},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
|
||||||
@@ -0,0 +1,177 @@
|
|||||||
|
# TOPReward
|
||||||
|
|
||||||
|
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||||
|
|
||||||
|
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||||
|
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||||
|
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||||
|
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||||
|
|
||||||
|
- A trajectory video (a sequence of frames).
|
||||||
|
- A task instruction (e.g. _"open the drawer"_).
|
||||||
|
|
||||||
|
it builds a chat prompt of the form
|
||||||
|
|
||||||
|
```text
|
||||||
|
<video>
|
||||||
|
"The above video shows a robot manipulation trajectory that completes the
|
||||||
|
following task: <instruction> Decide whether the above statement is True
|
||||||
|
or not. The answer is: True"
|
||||||
|
```
|
||||||
|
|
||||||
|
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
|
||||||
|
|
||||||
|
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
|
||||||
|
|
||||||
|
## What the LeRobot integration covers
|
||||||
|
|
||||||
|
- Standard `reward_model.type=topreward` configuration through LeRobot.
|
||||||
|
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
|
||||||
|
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
|
||||||
|
- `compute_reward()` returns one scalar log-prob per sample.
|
||||||
|
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
|
||||||
|
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
|
||||||
|
|
||||||
|
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
1. Install LeRobot following the [Installation Guide](./installation).
|
||||||
|
2. Install the TOPReward optional extra:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[topreward]"
|
||||||
|
```
|
||||||
|
|
||||||
|
or, with `uv` from a source checkout:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync --extra topreward
|
||||||
|
```
|
||||||
|
|
||||||
|
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||||
|
|
||||||
|
## Model Inputs and Outputs
|
||||||
|
|
||||||
|
TOPReward expects:
|
||||||
|
|
||||||
|
- A trajectory video or sequence of frames.
|
||||||
|
- A natural-language task description.
|
||||||
|
|
||||||
|
In LeRobot datasets the preprocessor reads:
|
||||||
|
|
||||||
|
| Config field | Default | Meaning |
|
||||||
|
| ------------------------- | --------------------------- | --------------------------------------------- |
|
||||||
|
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||||
|
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
|
||||||
|
| `reward_model.max_frames` | `16` | Cap on frames per sample |
|
||||||
|
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||||
|
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||||
|
|
||||||
|
The model returns:
|
||||||
|
|
||||||
|
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Load the reward model directly
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||||
|
|
||||||
|
cfg = TOPRewardConfig(
|
||||||
|
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
reward_model = TOPRewardModel(cfg)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use the reward factory
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||||
|
|
||||||
|
cfg = make_reward_model_config(
|
||||||
|
"topreward",
|
||||||
|
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||||
|
device="cuda",
|
||||||
|
image_key="observation.images.top",
|
||||||
|
)
|
||||||
|
reward_model = make_reward_model(cfg)
|
||||||
|
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||||
|
```
|
||||||
|
|
||||||
|
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
|
||||||
|
|
||||||
|
### Offline dataset labeling
|
||||||
|
|
||||||
|
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||||
|
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||||
|
--dataset-repo-id lerobot/libero_10_image \
|
||||||
|
--num-samples 15 \
|
||||||
|
--device cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
Then render the progress overlay for any episode:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run examples/dataset/create_progress_videos.py \
|
||||||
|
--repo-id lerobot/libero_10_image \
|
||||||
|
--episode 0 \
|
||||||
|
--progress-file topreward_progress.parquet \
|
||||||
|
--gif
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Notes
|
||||||
|
|
||||||
|
### Prompt knobs
|
||||||
|
|
||||||
|
The default prompt mirrors the upstream paper:
|
||||||
|
|
||||||
|
```text
|
||||||
|
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
|
||||||
|
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||||
|
```
|
||||||
|
|
||||||
|
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
|
||||||
|
|
||||||
|
### Chat template
|
||||||
|
|
||||||
|
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
|
||||||
|
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
|
||||||
|
- TOPReward inherits the underlying VLM's biases.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||||
|
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||||
|
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||||
|
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{chen2026topreward,
|
||||||
|
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||||
|
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||||
|
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||||
|
Krishna, Ranjay},
|
||||||
|
journal={arXiv preprint arXiv:2602.19313},
|
||||||
|
year={2026}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.
|
||||||
@@ -0,0 +1,235 @@
|
|||||||
|
# VLA-JEPA
|
||||||
|
|
||||||
|
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
VLA-JEPA has three main components:
|
||||||
|
|
||||||
|
| Component | Module | Role |
|
||||||
|
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||||
|
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||||
|
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||||
|
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||||
|
|
||||||
|
### Data flow
|
||||||
|
|
||||||
|
**Training:**
|
||||||
|
|
||||||
|
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||||
|
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||||
|
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||||
|
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||||
|
|
||||||
|
**Inference:**
|
||||||
|
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||||
|
|
||||||
|
### Action head details
|
||||||
|
|
||||||
|
Available presets via `action_model_type`:
|
||||||
|
|
||||||
|
| Preset | Hidden dim | Heads | Head dim |
|
||||||
|
| ------- | ---------- | ----- | -------- |
|
||||||
|
| `DiT-B` | 768 | 12 | 64 |
|
||||||
|
| `DiT-L` | 1536 | 32 | 48 |
|
||||||
|
|
||||||
|
### World model details
|
||||||
|
|
||||||
|
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||||
|
|
||||||
|
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||||
|
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||||
|
|
||||||
|
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pretrained Checkpoints
|
||||||
|
|
||||||
|
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||||
|
|
||||||
|
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||||
|
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||||
|
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||||
|
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||||
|
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||||
|
|
||||||
|
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Key parameters in `VLAJEPAConfig`:
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||||
|
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||||
|
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||||
|
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||||
|
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||||
|
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||||
|
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||||
|
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||||
|
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||||
|
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||||
|
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||||
|
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||||
|
|
||||||
|
### Full training from scratch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
policy.type=vla_jepa \
|
||||||
|
policy.repo_id=your_org/your_repo \
|
||||||
|
dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fine-tuning from a pretrained checkpoint
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--policy.freeze_qwen=true \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fine-tuning on a different embodiment
|
||||||
|
|
||||||
|
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||||
|
|
||||||
|
The layers that depend on `action_dim` and `state_dim` are:
|
||||||
|
|
||||||
|
| Layer | Key prefix |
|
||||||
|
| ----------------------------------------- | ----------------------------------- |
|
||||||
|
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||||
|
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||||
|
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--policy.freeze_qwen=true \
|
||||||
|
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||||
|
|
||||||
|
### Reproducing the LIBERO results
|
||||||
|
|
||||||
|
**Training on LIBERO:**
|
||||||
|
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||||
|
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
|
--steps=30000
|
||||||
|
```
|
||||||
|
|
||||||
|
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||||
|
--eval.n_episodes=10 \
|
||||||
|
--eval.batch_size=5
|
||||||
|
```
|
||||||
|
|
||||||
|
To evaluate a subset of tasks only:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_10 \
|
||||||
|
--env.task_ids='[0,1,2]' \
|
||||||
|
--eval.n_episodes=10 \
|
||||||
|
--eval.batch_size=5
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected results:**
|
||||||
|
|
||||||
|
| Suite | Episodes | Successes | Success Rate |
|
||||||
|
| -------------- | -------- | --------- | ------------ |
|
||||||
|
| libero_spatial | 100 | 93 | **95.0%** |
|
||||||
|
| libero_object | 100 | 100 | **100.0%** |
|
||||||
|
| libero_goal | 100 | 98 | **98.0%** |
|
||||||
|
| libero_10 | 100 | 96 | **93.0%** |
|
||||||
|
| **Overall** | **400** | **387** | **96.5%** |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Fine-tuning on datasets with a different number of cameras
|
||||||
|
|
||||||
|
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||||
|
|
||||||
|
**Default behaviour — view padding / trimming (no action required)**
|
||||||
|
|
||||||
|
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||||
|
|
||||||
|
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||||
|
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||||
|
|
||||||
|
**Option 1 — Disable the world model**
|
||||||
|
|
||||||
|
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.enable_world_model=false \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=your_org/single_camera_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2 — Reinitialize the predictor input projection**
|
||||||
|
|
||||||
|
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||||
|
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||||
|
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||||
|
year = {2026},
|
||||||
|
eprint = {2602.10098},
|
||||||
|
archivePrefix = {arXiv},
|
||||||
|
primaryClass = {cs.RO},
|
||||||
|
url = {https://arxiv.org/abs/2602.10098},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||||
@@ -165,7 +165,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_vlabench_primitive \
|
--output_dir=./outputs/smolvla_vlabench_primitive \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||||
|
|
||||||
|
Spawns one single-GPU ``h200`` job that:
|
||||||
|
|
||||||
|
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||||
|
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||||
|
3. runs the plan / interjections / vqa modules across the dataset
|
||||||
|
in free-form mode (each episode generates its own subtasks +
|
||||||
|
memory),
|
||||||
|
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||||
|
or back to ``--repo_id``.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||||
|
|
||||||
|
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||||
|
run. For larger datasets, scale to ``h200x4`` and raise
|
||||||
|
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from huggingface_hub import get_token, run_job
|
||||||
|
|
||||||
|
token = os.environ.get("HF_TOKEN") or get_token()
|
||||||
|
if not token:
|
||||||
|
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||||
|
|
||||||
|
CMD = (
|
||||||
|
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||||
|
"pip install --no-deps "
|
||||||
|
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||||
|
"pip install --upgrade-strategy only-if-needed "
|
||||||
|
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||||
|
"openai && "
|
||||||
|
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||||
|
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||||
|
"lerobot-annotate "
|
||||||
|
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||||
|
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
|
||||||
|
"--push_to_hub=true "
|
||||||
|
"--vlm.backend=openai "
|
||||||
|
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||||
|
"--vlm.num_gpus=1 "
|
||||||
|
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||||
|
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||||
|
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||||
|
"--vlm.serve_ready_timeout_s=1800 "
|
||||||
|
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
|
||||||
|
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
job = run_job(
|
||||||
|
image="vllm/vllm-openai:latest",
|
||||||
|
command=["bash", "-c", CMD],
|
||||||
|
flavor="h200",
|
||||||
|
secrets={"HF_TOKEN": token},
|
||||||
|
timeout="2h",
|
||||||
|
)
|
||||||
|
print(f"Job URL: {job.url}")
|
||||||
|
print(f"Job ID: {job.id}")
|
||||||
+44
-11
@@ -115,8 +115,8 @@ dataset = [
|
|||||||
]
|
]
|
||||||
training = [
|
training = [
|
||||||
"lerobot[dataset]",
|
"lerobot[dataset]",
|
||||||
"accelerate>=1.10.0,<2.0.0",
|
"wandb>=0.24.0,<0.28.0",
|
||||||
"wandb>=0.24.0,<0.25.0",
|
"lerobot[accelerate-dep]",
|
||||||
]
|
]
|
||||||
hardware = [
|
hardware = [
|
||||||
"lerobot[pynput-dep]",
|
"lerobot[pynput-dep]",
|
||||||
@@ -142,7 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
|||||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
|
||||||
|
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||||
@@ -177,7 +178,12 @@ unitree_g1 = [
|
|||||||
"lerobot[matplotlib-dep]",
|
"lerobot[matplotlib-dep]",
|
||||||
"lerobot[pygame-dep]",
|
"lerobot[pygame-dep]",
|
||||||
]
|
]
|
||||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
|
||||||
|
reachy2 = [
|
||||||
|
"reachy2_sdk>=1.0.15,<1.1.0",
|
||||||
|
"grpcio<=1.73.1",
|
||||||
|
"protobuf<=6.32.0",
|
||||||
|
]
|
||||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||||
@@ -198,7 +204,8 @@ wallx = [
|
|||||||
"lerobot[qwen-vl-utils-dep]",
|
"lerobot[qwen-vl-utils-dep]",
|
||||||
]
|
]
|
||||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||||
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||||
groot = [
|
groot = [
|
||||||
"lerobot[transformers-dep]",
|
"lerobot[transformers-dep]",
|
||||||
@@ -211,25 +218,43 @@ groot = [
|
|||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||||
|
topreward = ["lerobot[transformers-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||||
|
|
||||||
|
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||||
|
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||||
|
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||||
|
# (see examples/annotations/run_hf_job.py).
|
||||||
|
annotations = [
|
||||||
|
"lerobot[dataset]",
|
||||||
|
"lerobot[transformers-dep]",
|
||||||
|
"openai>=1.40,<2.0",
|
||||||
|
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||||
|
# uv's single unified lock would then cap ``torch`` for every extra
|
||||||
|
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||||
|
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||||
|
# install it locally only if you run your own ``vllm serve``.
|
||||||
|
]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||||
|
|
||||||
# Simulation
|
# Simulation
|
||||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||||
@@ -274,10 +299,12 @@ all = [
|
|||||||
"lerobot[multi_task_dit]",
|
"lerobot[multi_task_dit]",
|
||||||
"lerobot[wallx]",
|
"lerobot[wallx]",
|
||||||
"lerobot[pi]",
|
"lerobot[pi]",
|
||||||
|
"lerobot[molmoact2]",
|
||||||
"lerobot[smolvla]",
|
"lerobot[smolvla]",
|
||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
"lerobot[xvla]",
|
"lerobot[xvla]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
|
"lerobot[vla_jepa]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
"lerobot[test]",
|
"lerobot[test]",
|
||||||
@@ -288,6 +315,8 @@ all = [
|
|||||||
"lerobot[libero]; sys_platform == 'linux'",
|
"lerobot[libero]; sys_platform == 'linux'",
|
||||||
"lerobot[metaworld]",
|
"lerobot[metaworld]",
|
||||||
"lerobot[sarm]",
|
"lerobot[sarm]",
|
||||||
|
"lerobot[robometer]",
|
||||||
|
"lerobot[topreward]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||||
]
|
]
|
||||||
@@ -309,6 +338,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
|||||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
@@ -327,7 +357,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
|||||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
lerobot = ["envs/*.json"]
|
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
@@ -403,8 +433,11 @@ default.extend-ignore-identifiers-re = [
|
|||||||
"ein",
|
"ein",
|
||||||
"thw",
|
"thw",
|
||||||
"inpt",
|
"inpt",
|
||||||
|
"arange",
|
||||||
|
"is_compileable",
|
||||||
"ROBOTIS",
|
"ROBOTIS",
|
||||||
"OT_VALUE"
|
"OT_VALUE",
|
||||||
|
"VanderBilt"
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO: Uncomment when ready to use
|
# TODO: Uncomment when ready to use
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||||
|
``language_events`` columns for LeRobot datasets.
|
||||||
|
|
||||||
|
The pipeline is decomposed into three independently runnable modules whose
|
||||||
|
outputs are staged per-episode before a final parquet rewrite:
|
||||||
|
|
||||||
|
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||||
|
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||||
|
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .config import AnnotationPipelineConfig
|
||||||
|
from .validator import StagingValidator, ValidationReport
|
||||||
|
from .writer import LanguageColumnsWriter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnnotationPipelineConfig",
|
||||||
|
"LanguageColumnsWriter",
|
||||||
|
"StagingValidator",
|
||||||
|
"ValidationReport",
|
||||||
|
]
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlanConfig:
|
||||||
|
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||||
|
n_task_rephrasings: int = 10
|
||||||
|
|
||||||
|
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||||
|
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||||
|
derive_task_from_video: str = "if_short"
|
||||||
|
derive_task_min_words: int = 3
|
||||||
|
|
||||||
|
# --- Frame input: timestamped contact sheets (always on) ---------------
|
||||||
|
# The subtask describe/segment passes ALWAYS render the episode as
|
||||||
|
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
|
||||||
|
# grids with each frame's timestamp burned into its corner, so the VLM
|
||||||
|
# cites the exact source time of a boundary directly. This is far cheaper
|
||||||
|
# in vision tokens than one image per frame (≈2× faster subtask generation
|
||||||
|
# in practice), which is why the sampling is dense by default.
|
||||||
|
#
|
||||||
|
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
|
||||||
|
frames_per_second: float = 2.0
|
||||||
|
# Frame budget per VLM call (= columns × rows × sheets). When a whole
|
||||||
|
# episode sampled at ``frames_per_second`` exceeds this, the episode is
|
||||||
|
# AUTOMATICALLY split into consecutive windows of
|
||||||
|
# ``max_frames_per_prompt`` frames each (one describe→segment call per
|
||||||
|
# window, still at the full ``frames_per_second`` density), and the
|
||||||
|
# per-window spans are merged + stitched into one contiguous cover. So an
|
||||||
|
# episode of any length is always covered at the full sampling density.
|
||||||
|
max_frames_per_prompt: int = 60
|
||||||
|
contact_sheet_columns: int = 5
|
||||||
|
contact_sheet_frames_per_sheet: int = 20
|
||||||
|
contact_sheet_frame_width: int = 224
|
||||||
|
contact_sheet_quality: int = 84
|
||||||
|
|
||||||
|
min_subtask_seconds: float = 1.5
|
||||||
|
plan_max_steps: int = 8
|
||||||
|
|
||||||
|
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||||
|
# invented from the task text (+1 VLM call/episode).
|
||||||
|
subtask_describe_first: bool = True
|
||||||
|
|
||||||
|
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||||
|
emit_plan: bool = True
|
||||||
|
|
||||||
|
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
|
||||||
|
# Symmetric counterpart of ``emit_plan``.
|
||||||
|
emit_memory: bool = True
|
||||||
|
|
||||||
|
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||||
|
|
||||||
|
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||||
|
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskAugAxesConfig:
|
||||||
|
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||||
|
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||||
|
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||||
|
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
synonym_paraphrase: int = 3
|
||||||
|
omit_arm: int = 3
|
||||||
|
omit_orientation: int = 2
|
||||||
|
omit_grasp_method: int = 2
|
||||||
|
combined_omissions: int = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InterjectionsConfig:
|
||||||
|
"""``interjections`` module: interjections + paired speech."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||||
|
max_interjections_per_episode: int = 3
|
||||||
|
interjection_min_t: float = 2.0
|
||||||
|
|
||||||
|
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||||
|
interjection_window_seconds: float = 2.0
|
||||||
|
interjection_window_frames: int = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VqaConfig:
|
||||||
|
"""``vqa`` module: general VQA."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
vqa_emission_hz: float = 1.0
|
||||||
|
K: int = 1
|
||||||
|
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||||
|
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||||
|
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||||
|
|
||||||
|
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||||
|
restrict_to_default_camera: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VlmConfig:
|
||||||
|
"""Shared Qwen-VL client configuration."""
|
||||||
|
|
||||||
|
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
|
||||||
|
# auto_serve=True); ``stub`` is for tests.
|
||||||
|
backend: str = "openai"
|
||||||
|
model_id: str = "Qwen/Qwen3.6-27B"
|
||||||
|
|
||||||
|
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||||
|
api_base: str = "http://localhost:8000/v1"
|
||||||
|
api_key: str = "EMPTY"
|
||||||
|
|
||||||
|
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||||
|
auto_serve: bool = True
|
||||||
|
serve_port: int = 8000
|
||||||
|
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||||
|
serve_command: str | None = None
|
||||||
|
|
||||||
|
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||||
|
parallel_servers: int = 1
|
||||||
|
num_gpus: int = 0
|
||||||
|
client_concurrency: int = 16
|
||||||
|
serve_ready_timeout_s: float = 600.0
|
||||||
|
|
||||||
|
max_new_tokens: int = 512
|
||||||
|
temperature: float = 0.2
|
||||||
|
|
||||||
|
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||||
|
max_model_len: int | None = None
|
||||||
|
|
||||||
|
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||||
|
camera_key: str | None = None
|
||||||
|
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutorConfig:
|
||||||
|
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||||
|
|
||||||
|
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||||
|
episode_parallelism: int = 16
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AnnotationPipelineConfig:
|
||||||
|
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||||
|
|
||||||
|
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||||
|
# is on and ``new_repo_id`` unset.
|
||||||
|
repo_id: str | None = None
|
||||||
|
|
||||||
|
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||||
|
new_repo_id: str | None = None
|
||||||
|
|
||||||
|
root: Path | None = None
|
||||||
|
|
||||||
|
# Defaults to ``<root>/.annotate_staging/``.
|
||||||
|
staging_dir: Path | None = None
|
||||||
|
|
||||||
|
seed: int = 1729
|
||||||
|
|
||||||
|
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||||
|
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||||
|
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||||
|
|
||||||
|
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||||
|
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||||
|
|
||||||
|
skip_validation: bool = False
|
||||||
|
only_episodes: tuple[int, ...] | None = None
|
||||||
|
|
||||||
|
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
|
||||||
|
# library default (torchcodec when available, else PyAV). Or pin
|
||||||
|
# ``"torchcodec"`` / ``"pyav"`` explicitly.
|
||||||
|
video_backend: str | None = None
|
||||||
|
|
||||||
|
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||||
|
push_to_hub: bool = False
|
||||||
|
push_private: bool = False
|
||||||
|
push_commit_message: str | None = None
|
||||||
|
|
||||||
|
def resolved_staging_dir(self, root: Path) -> Path:
|
||||||
|
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||||
@@ -0,0 +1,253 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""In-process executor that runs the annotation phases.
|
||||||
|
|
||||||
|
The executor runs **six phases** in dependency order:
|
||||||
|
|
||||||
|
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||||
|
phase 2: ``interjections`` module (interjections + speech)
|
||||||
|
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||||
|
interjection timestamp produced by phase 2
|
||||||
|
phase 4: ``vqa`` module (VQA)
|
||||||
|
phase 5: validator
|
||||||
|
phase 6: writer
|
||||||
|
|
||||||
|
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||||
|
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||||
|
timestamps.
|
||||||
|
|
||||||
|
Distributed execution is provided by Hugging Face Jobs (see
|
||||||
|
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||||
|
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||||
|
Episode-level concurrency is controlled by
|
||||||
|
``ExecutorConfig.episode_parallelism``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import AnnotationPipelineConfig
|
||||||
|
from .reader import EpisodeRecord, iter_episodes
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
from .validator import StagingValidator
|
||||||
|
from .writer import LanguageColumnsWriter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhaseResult:
|
||||||
|
"""Summary of one pipeline phase across all episodes."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
episodes_processed: int
|
||||||
|
episodes_skipped: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineRunSummary:
|
||||||
|
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||||
|
|
||||||
|
phases: list[PhaseResult]
|
||||||
|
written_paths: list[Path]
|
||||||
|
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Executor:
|
||||||
|
"""Run all six phases over a dataset root in-process.
|
||||||
|
|
||||||
|
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||||
|
(a thread pool); cluster-level concurrency comes from running this
|
||||||
|
executor inside a Hugging Face Job. Tests construct the executor
|
||||||
|
directly with stub modules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: AnnotationPipelineConfig
|
||||||
|
plan: Any # PlanSubtasksMemoryModule
|
||||||
|
interjections: Any # InterjectionsAndSpeechModule
|
||||||
|
vqa: Any # GeneralVqaModule
|
||||||
|
writer: LanguageColumnsWriter
|
||||||
|
validator: StagingValidator
|
||||||
|
|
||||||
|
def run(self, root: Path) -> PipelineRunSummary:
|
||||||
|
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||||
|
n = len(records)
|
||||||
|
if n == 0:
|
||||||
|
raise ValueError(f"No episodes found under {root}/data/")
|
||||||
|
|
||||||
|
print(f"[annotate] {n} episodes total", flush=True)
|
||||||
|
|
||||||
|
staging_dir = self.config.resolved_staging_dir(root)
|
||||||
|
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
phases: list[PhaseResult] = []
|
||||||
|
|
||||||
|
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||||
|
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||||
|
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||||
|
# reads the ``plan`` module's subtask rows from the same staging
|
||||||
|
# tree to ground the interjection prompt in the correct local subtask.
|
||||||
|
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||||
|
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||||
|
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||||
|
# Phase 4: ``vqa`` module (VQA)
|
||||||
|
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||||
|
|
||||||
|
print("[annotate] running validator...", flush=True)
|
||||||
|
report = self.validator.validate(records, staging_dir)
|
||||||
|
if not report.ok and not self.config.skip_validation:
|
||||||
|
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||||
|
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||||
|
|
||||||
|
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||||
|
written = self.writer.write_all(records, staging_dir, root)
|
||||||
|
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||||
|
|
||||||
|
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||||
|
# Idempotent and additive: existing user metadata is preserved.
|
||||||
|
self._ensure_annotation_metadata_in_info(root)
|
||||||
|
|
||||||
|
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||||
|
"""Write language features and canonical tools to ``meta/info.json``.
|
||||||
|
|
||||||
|
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||||
|
``language_events`` to parquet shards. The metadata must advertise
|
||||||
|
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||||
|
cast against the old schema and fail on the extra parquet columns.
|
||||||
|
"""
|
||||||
|
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||||
|
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||||
|
|
||||||
|
info_path = root / "meta" / "info.json"
|
||||||
|
if not info_path.exists():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
info = load_info(root)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
merged_features = {**info.features, **language_feature_info()}
|
||||||
|
if merged_features != info.features:
|
||||||
|
info.features = merged_features
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
existing = info.tools or []
|
||||||
|
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||||
|
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||||
|
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if changed:
|
||||||
|
write_info(info, root)
|
||||||
|
print(
|
||||||
|
"[annotate] meta/info.json: "
|
||||||
|
f"language_features={list(language_feature_info())}, "
|
||||||
|
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_module_phase(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
records: list[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
module: Any,
|
||||||
|
) -> PhaseResult:
|
||||||
|
if not module.enabled:
|
||||||
|
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||||
|
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||||
|
n = len(records)
|
||||||
|
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||||
|
print(
|
||||||
|
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||||
|
i, record = idx_record
|
||||||
|
ep_start = time.time()
|
||||||
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
|
module.run_episode(record, staging)
|
||||||
|
return i, record.episode_index, time.time() - ep_start
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
if parallelism == 1:
|
||||||
|
for i, record in enumerate(records, 1):
|
||||||
|
_, ep_idx, elapsed = _do((i, record))
|
||||||
|
processed += 1
|
||||||
|
print(
|
||||||
|
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||||
|
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||||
|
for fut in as_completed(futures):
|
||||||
|
i, ep_idx, elapsed = fut.result()
|
||||||
|
processed += 1
|
||||||
|
print(
|
||||||
|
f"[annotate] {name} episode {processed}/{n} "
|
||||||
|
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
total = time.time() - t0
|
||||||
|
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||||
|
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||||
|
|
||||||
|
def _run_plan_update_phase( # noqa: PLR0915
|
||||||
|
self, records: list[EpisodeRecord], staging_dir: Path
|
||||||
|
) -> PhaseResult:
|
||||||
|
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||||
|
|
||||||
|
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||||
|
produced the timestamps. This phase therefore calls back into the
|
||||||
|
``plan`` module with the interjection timestamps so its existing
|
||||||
|
prompt path is reused.
|
||||||
|
"""
|
||||||
|
if not self.plan.enabled or not self.interjections.enabled:
|
||||||
|
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||||
|
processed = 0
|
||||||
|
for record in records:
|
||||||
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
|
interjection_rows = [
|
||||||
|
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||||
|
]
|
||||||
|
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||||
|
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||||
|
if interjection_times:
|
||||||
|
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||||
|
processed += 1
|
||||||
|
# Episodes without any interjections are skipped (no plan refresh
|
||||||
|
# needed); count them so the summary's processed+skipped == total.
|
||||||
|
return PhaseResult(
|
||||||
|
name="plan_update",
|
||||||
|
episodes_processed=processed,
|
||||||
|
episodes_skipped=len(records) - processed,
|
||||||
|
)
|
||||||
@@ -0,0 +1,481 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Keyframe extraction for the annotation pipeline.
|
||||||
|
|
||||||
|
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||||
|
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||||
|
visual content. The pipeline shares one provider across modules and one
|
||||||
|
episode at a time, with a small per-episode cache so multiple modules
|
||||||
|
querying the same timestamp pay decode cost once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.video import VideoEncoderConfig
|
||||||
|
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord, snap_to_frame
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameProvider(Protocol):
|
||||||
|
"""Decodes camera frames at episode-relative timestamps."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||||
|
|
||||||
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||||
|
|
||||||
|
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||||
|
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||||
|
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||||
|
boundary.
|
||||||
|
|
||||||
|
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||||
|
to the provider's default camera so existing single-camera callers
|
||||||
|
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||||
|
|
||||||
|
Sampling is uniform across the episode duration. Frames are
|
||||||
|
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||||
|
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||||
|
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||||
|
no camera available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _NullProvider:
|
||||||
|
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def null_provider() -> FrameProvider:
|
||||||
|
return _NullProvider()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoFrameProvider:
|
||||||
|
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||||
|
|
||||||
|
By default the *first* camera key is used for the ``plan`` module
|
||||||
|
(subtask decomposition) and the ``interjections`` module (interjection
|
||||||
|
scenarios) — those prompts care about *what is happening*, not which
|
||||||
|
angle. The ``vqa`` module instead iterates over every camera in
|
||||||
|
:attr:`camera_keys` so each frame's
|
||||||
|
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||||
|
grounded against.
|
||||||
|
|
||||||
|
``camera_key`` overrides the default-camera choice but does not restrict
|
||||||
|
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||||
|
``video_for_episode`` to read a non-default stream.
|
||||||
|
|
||||||
|
Caches up to ``cache_size`` decoded frames per process to keep
|
||||||
|
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: Path
|
||||||
|
camera_key: str | None = None
|
||||||
|
tolerance_s: float = 1e-2
|
||||||
|
cache_size: int = 256
|
||||||
|
# Keyframe decode backend forwarded to
|
||||||
|
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
|
||||||
|
# uses the library default (torchcodec when available, else PyAV).
|
||||||
|
video_backend: str | None = None
|
||||||
|
_meta: Any = field(default=None, init=False, repr=False)
|
||||||
|
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||||
|
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||||
|
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||||
|
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||||
|
# one-shot warn flag against concurrent updates from worker threads.
|
||||||
|
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||||
|
# Serializes decode_video_frames calls: torchcodec hands out one
|
||||||
|
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
|
||||||
|
# is not safe to drive from multiple threads at once.
|
||||||
|
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||||
|
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||||
|
|
||||||
|
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||||
|
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||||
|
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||||
|
# only for video-stored cameras. Image-stored cameras (also in
|
||||||
|
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||||
|
# default — to video keys.
|
||||||
|
keys = list(self._meta.video_keys)
|
||||||
|
# Last-resort fallback: if metadata didn't surface any video keys but
|
||||||
|
# the caller explicitly named a camera (``--vlm.camera_key=...``),
|
||||||
|
# trust them — the key is by definition known to exist on the dataset.
|
||||||
|
if not keys and self.camera_key:
|
||||||
|
keys = [self.camera_key]
|
||||||
|
self._camera_keys = keys
|
||||||
|
if self.camera_key is None:
|
||||||
|
self.camera_key = keys[0] if keys else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""All ``observation.images.*`` keys available on this dataset."""
|
||||||
|
return list(self._camera_keys)
|
||||||
|
|
||||||
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if not timestamps or target is None:
|
||||||
|
return []
|
||||||
|
# Snap each request to the nearest real frame timestamp: callers
|
||||||
|
# sample uniform grids whose points land mid-frame, and
|
||||||
|
# ``decode_video_frames`` rejects queries farther than
|
||||||
|
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
|
||||||
|
# repeat queries through the cache.
|
||||||
|
if record.frame_timestamps:
|
||||||
|
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
|
||||||
|
|
||||||
|
out: list[Any] = []
|
||||||
|
misses: list[float] = []
|
||||||
|
miss_indices: list[int] = []
|
||||||
|
with self._lock:
|
||||||
|
for i, ts in enumerate(timestamps):
|
||||||
|
key = (record.episode_index, target, round(float(ts), 6))
|
||||||
|
cached = self._cache.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
out.append(cached)
|
||||||
|
else:
|
||||||
|
out.append(None)
|
||||||
|
misses.append(float(ts))
|
||||||
|
miss_indices.append(i)
|
||||||
|
|
||||||
|
if misses:
|
||||||
|
decoded = self._decode(record.episode_index, misses, target)
|
||||||
|
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||||
|
# or an empty list if decoding failed wholesale. A partial list
|
||||||
|
# would mean a frame/timestamp misalignment, so only pair them up
|
||||||
|
# when the counts match (``strict=True`` then guards regressions).
|
||||||
|
if len(decoded) == len(miss_indices):
|
||||||
|
with self._lock:
|
||||||
|
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||||
|
out[i] = frame
|
||||||
|
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||||
|
if len(self._cache) >= self.cache_size:
|
||||||
|
self._cache.pop(next(iter(self._cache)))
|
||||||
|
self._cache[key] = frame
|
||||||
|
# filter out any None left over from decode failures
|
||||||
|
return [frame for frame in out if frame is not None]
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||||
|
|
||||||
|
The whole episode duration is covered; the model picks subtask
|
||||||
|
boundaries from the temporal pooling it does internally. Frames are
|
||||||
|
``torch.Tensor`` (see :meth:`frames_at`).
|
||||||
|
"""
|
||||||
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||||
|
if n_frames == len(record.frame_timestamps):
|
||||||
|
timestamps = list(record.frame_timestamps)
|
||||||
|
else:
|
||||||
|
t0 = record.frame_timestamps[0]
|
||||||
|
t_last = record.frame_timestamps[-1]
|
||||||
|
if t_last <= t0:
|
||||||
|
timestamps = [float(t0)] * n_frames
|
||||||
|
else:
|
||||||
|
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||||
|
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||||
|
return self.frames_at(record, timestamps, camera_key=target)
|
||||||
|
|
||||||
|
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||||
|
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||||
|
|
||||||
|
Returns ``None`` if the dataset has no video tracks or extraction
|
||||||
|
failed. Skips re-extract when the cached clip already exists.
|
||||||
|
Re-encodes to H.264 via
|
||||||
|
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
|
||||||
|
mp4 is decodable by every downstream video processor — stream-copy
|
||||||
|
would inherit the source codec (often AV1 in modern LeRobot
|
||||||
|
datasets), which vllm's libav build cannot decode.
|
||||||
|
"""
|
||||||
|
if self.camera_key is None:
|
||||||
|
return None
|
||||||
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||||
|
if out_path.exists() and out_path.stat().st_size > 0:
|
||||||
|
return out_path
|
||||||
|
ep = self._meta.episodes[record.episode_index]
|
||||||
|
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||||
|
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||||
|
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||||
|
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
|
||||||
|
try:
|
||||||
|
reencode_video(
|
||||||
|
src,
|
||||||
|
out_path,
|
||||||
|
camera_encoder=encoder,
|
||||||
|
overwrite=True,
|
||||||
|
start_time_s=from_timestamp,
|
||||||
|
end_time_s=to_timestamp,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||||
|
|
||||||
|
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||||
|
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||||
|
|
||||||
|
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||||
|
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
|
||||||
|
one explicitly). Returns one frame per requested timestamp, or ``[]``
|
||||||
|
if decoding failed — callers treat ``[]`` as "no frames available".
|
||||||
|
"""
|
||||||
|
ep = self._meta.episodes[episode_index]
|
||||||
|
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||||
|
shifted = [from_timestamp + ts for ts in timestamps]
|
||||||
|
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# The module phases decode under a ThreadPoolExecutor (see
|
||||||
|
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
|
||||||
|
# per-file decoder is single-threaded, so serialize decodes on a
|
||||||
|
# dedicated lock. Frame extraction is a small fraction of episode
|
||||||
|
# wall time (VLM calls dominate), so the contention is cheap.
|
||||||
|
with self._decode_lock:
|
||||||
|
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||||
|
decoded = decode_video_frames(
|
||||||
|
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
|
||||||
|
)
|
||||||
|
return list(decoded)
|
||||||
|
except Exception as exc:
|
||||||
|
# Log loudly the first time so a silent vqa-module no-op (every
|
||||||
|
# prompt skipped because frames_at returned []) is debuggable from
|
||||||
|
# the job log instead of post-hoc parquet inspection. Subsequent
|
||||||
|
# failures stay quiet.
|
||||||
|
with self._lock:
|
||||||
|
already_warned = self._warned_decode_fail
|
||||||
|
if not already_warned:
|
||||||
|
self._warned_decode_fail = True
|
||||||
|
if not already_warned:
|
||||||
|
logger.warning(
|
||||||
|
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
|
||||||
|
episode_index,
|
||||||
|
camera_key,
|
||||||
|
video_path,
|
||||||
|
self.video_backend,
|
||||||
|
exc,
|
||||||
|
exc_info=exc,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame_provider(
|
||||||
|
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||||
|
) -> FrameProvider:
|
||||||
|
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||||
|
try:
|
||||||
|
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||||
|
except Exception:
|
||||||
|
return null_provider()
|
||||||
|
if provider.camera_key is None:
|
||||||
|
return null_provider()
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def _frame_to_pil(frame: Any) -> Any:
|
||||||
|
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||||
|
|
||||||
|
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||||
|
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||||
|
the VLM-message boundary, because the chat backends expect PIL images /
|
||||||
|
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||||
|
"""
|
||||||
|
if not isinstance(frame, torch.Tensor):
|
||||||
|
return frame
|
||||||
|
array = frame.detach().cpu()
|
||||||
|
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||||
|
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||||
|
if array.shape[-1] == 1:
|
||||||
|
array = array.squeeze(-1)
|
||||||
|
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||||
|
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||||
|
|
||||||
|
|
||||||
|
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||||
|
|
||||||
|
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||||
|
into a content array without a separate emptiness check.
|
||||||
|
"""
|
||||||
|
if not frames:
|
||||||
|
return []
|
||||||
|
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||||
|
|
||||||
|
|
||||||
|
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||||
|
"""Wrap a video file URL as one ``video_url`` block.
|
||||||
|
|
||||||
|
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||||
|
ktransformers serve), where the server handles frame sampling.
|
||||||
|
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
return []
|
||||||
|
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
|
||||||
|
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
|
||||||
|
|
||||||
|
A solid black badge with white text, so a VLM reading a contact sheet can
|
||||||
|
cite the exact source time of each tile (e.g. ``012.50s``) directly,
|
||||||
|
instead of the caller having to map tile position back to time. Mirrors
|
||||||
|
the macrodata/refiner contact-sheet convention.
|
||||||
|
"""
|
||||||
|
from PIL import ImageDraw, ImageFont
|
||||||
|
|
||||||
|
result = image.copy()
|
||||||
|
draw = ImageDraw.Draw(result)
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
label = f"{timestamp:06.2f}s"
|
||||||
|
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
|
||||||
|
text_w, text_h = right - left, bottom - top
|
||||||
|
pad = max(3, round(min(image.width, image.height) * 0.018))
|
||||||
|
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
|
||||||
|
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def to_contact_sheet_blocks(
|
||||||
|
frames: Sequence[Any],
|
||||||
|
timestamps: Sequence[float],
|
||||||
|
*,
|
||||||
|
columns: int = 5,
|
||||||
|
frames_per_sheet: int = 20,
|
||||||
|
frame_width: int = 224,
|
||||||
|
quality: int = 84,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
|
||||||
|
|
||||||
|
Each frame is resized to ``frame_width`` wide, stamped with its
|
||||||
|
episode-relative timestamp, and tiled row-major into grids of
|
||||||
|
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
|
||||||
|
block is returned per grid; many frames collapse into a few images, so a
|
||||||
|
long episode's temporal coverage stays dense at a fraction of the vision
|
||||||
|
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
|
||||||
|
aligned and equal length. Returns ``[]`` for empty input.
|
||||||
|
"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
if not frames:
|
||||||
|
return []
|
||||||
|
columns = max(1, columns)
|
||||||
|
frames_per_sheet = max(1, frames_per_sheet)
|
||||||
|
rows_per_sheet = math.ceil(frames_per_sheet / columns)
|
||||||
|
|
||||||
|
tiles: list[PIL.Image.Image] = []
|
||||||
|
for ts, frame in zip(timestamps, frames, strict=False):
|
||||||
|
img = _frame_to_pil(frame)
|
||||||
|
if not isinstance(img, PIL.Image.Image):
|
||||||
|
continue
|
||||||
|
img = img.convert("RGB")
|
||||||
|
if img.width != frame_width:
|
||||||
|
height = max(1, round(img.height * frame_width / img.width))
|
||||||
|
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
|
||||||
|
tiles.append(_draw_timestamp_badge(img, float(ts)))
|
||||||
|
if not tiles:
|
||||||
|
return []
|
||||||
|
|
||||||
|
blocks: list[dict[str, Any]] = []
|
||||||
|
for start in range(0, len(tiles), frames_per_sheet):
|
||||||
|
chunk = tiles[start : start + frames_per_sheet]
|
||||||
|
cell_w = max(tile.width for tile in chunk)
|
||||||
|
cell_h = max(tile.height for tile in chunk)
|
||||||
|
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
|
||||||
|
for i, tile in enumerate(chunk):
|
||||||
|
x = (i % columns) * cell_w
|
||||||
|
y = (i // columns) * cell_h
|
||||||
|
sheet.paste(tile, (x, y))
|
||||||
|
# JPEG round-trip at ``quality`` to match the refiner convention and
|
||||||
|
# shrink the wire payload; vision-token count is set by resolution, so
|
||||||
|
# the real saving is the grid packing, not the codec.
|
||||||
|
buf = io.BytesIO()
|
||||||
|
sheet.save(buf, format="JPEG", quality=quality)
|
||||||
|
buf.seek(0)
|
||||||
|
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
|
||||||
|
return blocks
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .general_vqa import GeneralVqaModule
|
||||||
|
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||||
|
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GeneralVqaModule",
|
||||||
|
"InterjectionsAndSpeechModule",
|
||||||
|
"PlanSubtasksMemoryModule",
|
||||||
|
]
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``vqa`` module: general VQA at a timed cadence.
|
||||||
|
|
||||||
|
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||||
|
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||||
|
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||||
|
window. For datasets with multiple cameras, every anchored frame produces
|
||||||
|
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||||
|
generated against that camera's frame and stamped with the matching
|
||||||
|
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||||
|
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||||
|
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||||
|
|
||||||
|
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||||
|
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||||
|
|
||||||
|
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||||
|
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||||
|
whose schema depends on the question type. Malformed JSON triggers one
|
||||||
|
retry inside :meth:`VlmClient.generate_json`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import VqaConfig
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..validator import classify_vqa_answer
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
|
||||||
|
|
||||||
|
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||||
|
"""Return the relative frame indices to anchor VQA emissions to.
|
||||||
|
|
||||||
|
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||||
|
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||||
|
available source frame timestamp.
|
||||||
|
"""
|
||||||
|
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||||
|
return []
|
||||||
|
t0 = frame_timestamps[0]
|
||||||
|
t_last = frame_timestamps[-1]
|
||||||
|
period = 1.0 / hz
|
||||||
|
indices: list[int] = []
|
||||||
|
t = t0
|
||||||
|
while t <= t_last + 1e-9:
|
||||||
|
# find the index of the nearest frame to t
|
||||||
|
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||||
|
for offset in range(k):
|
||||||
|
j = nearest_i + offset
|
||||||
|
if j >= len(frame_timestamps):
|
||||||
|
break
|
||||||
|
if not indices or indices[-1] != j:
|
||||||
|
indices.append(j)
|
||||||
|
t += period
|
||||||
|
# dedupe while preserving order
|
||||||
|
seen: set[int] = set()
|
||||||
|
deduped: list[int] = []
|
||||||
|
for i in indices:
|
||||||
|
if i in seen:
|
||||||
|
continue
|
||||||
|
seen.add(i)
|
||||||
|
deduped.append(i)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GeneralVqaModule:
|
||||||
|
"""Emit grounded VQA pairs at a timed cadence."""
|
||||||
|
|
||||||
|
vlm: VlmClient
|
||||||
|
config: VqaConfig
|
||||||
|
seed: int = 1729
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self.config.enabled
|
||||||
|
|
||||||
|
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||||
|
if not record.frame_timestamps:
|
||||||
|
staging.write("vqa", [])
|
||||||
|
return
|
||||||
|
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||||
|
anchor_idx = _emission_anchor_indices(
|
||||||
|
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||||
|
)
|
||||||
|
cameras = self._target_cameras()
|
||||||
|
if not cameras:
|
||||||
|
# No camera available — emit nothing rather than producing
|
||||||
|
# untagged rows that would fail validation. Surface a loud one-
|
||||||
|
# time warning so this is never silently a no-op.
|
||||||
|
if not self._warned_no_camera:
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"vqa module found no cameras on the frame provider — "
|
||||||
|
"every episode will emit zero VQA rows. Check that the "
|
||||||
|
"dataset declares observation.images.* features in "
|
||||||
|
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||||
|
"CLI now also seeds the cameras list as a fallback."
|
||||||
|
)
|
||||||
|
self._warned_no_camera = True
|
||||||
|
staging.write("vqa", [])
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build all messages first (one per (frame, camera)), then issue them
|
||||||
|
# as a single batched generate_json call so the client can fan them
|
||||||
|
# out concurrently.
|
||||||
|
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||||
|
for idx in anchor_idx:
|
||||||
|
ts = float(record.frame_timestamps[idx])
|
||||||
|
qtype = rng.choice(self.config.question_types)
|
||||||
|
for camera in cameras:
|
||||||
|
messages = self._build_messages(record, qtype, ts, camera)
|
||||||
|
# Skip cameras that decoded to zero frames at this ts: no point
|
||||||
|
# asking the VLM to ground a bbox without an image.
|
||||||
|
if not _has_image_block(messages):
|
||||||
|
continue
|
||||||
|
per_call.append((ts, camera, qtype, messages))
|
||||||
|
|
||||||
|
if not per_call:
|
||||||
|
staging.write("vqa", [])
|
||||||
|
return
|
||||||
|
|
||||||
|
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||||
|
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||||
|
qa = self._postprocess(result)
|
||||||
|
if qa is None:
|
||||||
|
continue
|
||||||
|
question, answer = qa
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": question,
|
||||||
|
"style": "vqa",
|
||||||
|
"timestamp": ts,
|
||||||
|
"camera": camera,
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": json.dumps(answer, sort_keys=True),
|
||||||
|
"style": "vqa",
|
||||||
|
"timestamp": ts,
|
||||||
|
"camera": camera,
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
staging.write("vqa", rows)
|
||||||
|
|
||||||
|
def _target_cameras(self) -> list[str]:
|
||||||
|
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||||
|
|
||||||
|
Defaults to every camera the provider exposes. Datasets with no
|
||||||
|
cameras (or test/null providers) yield an empty list, which makes
|
||||||
|
``run_episode`` a no-op.
|
||||||
|
|
||||||
|
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||||
|
only the provider's default camera (the single ``--vlm.camera_key``
|
||||||
|
stream), matching the plan / interjection modules so the whole
|
||||||
|
pipeline focuses on one view.
|
||||||
|
"""
|
||||||
|
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||||
|
if getattr(self.config, "restrict_to_default_camera", False):
|
||||||
|
default = getattr(self.frame_provider, "camera_key", None)
|
||||||
|
if default and default in all_cameras:
|
||||||
|
return [default]
|
||||||
|
# ``restrict_to_default_camera`` is set but the configured default
|
||||||
|
# isn't one the provider exposes. Returning it anyway would make
|
||||||
|
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||||
|
# fall through to every available camera instead.
|
||||||
|
if default:
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||||
|
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||||
|
default,
|
||||||
|
all_cameras,
|
||||||
|
)
|
||||||
|
return all_cameras
|
||||||
|
|
||||||
|
def _build_messages(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
question_type: str,
|
||||||
|
frame_timestamp: float,
|
||||||
|
camera_key: str,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
prompt = load_prompt("vqa").format(
|
||||||
|
episode_task=record.episode_task,
|
||||||
|
question_type=question_type,
|
||||||
|
)
|
||||||
|
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||||
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
return None
|
||||||
|
question = result.get("question")
|
||||||
|
answer = result.get("answer")
|
||||||
|
if not isinstance(question, str) or not question.strip():
|
||||||
|
return None
|
||||||
|
if not isinstance(answer, dict):
|
||||||
|
return None
|
||||||
|
# The validator will enforce shape; here we just sanity-check that the
|
||||||
|
# answer matches *some* known shape so we can drop garbage early.
|
||||||
|
if classify_vqa_answer(answer) is None:
|
||||||
|
return None
|
||||||
|
return question.strip(), answer
|
||||||
|
|
||||||
|
|
||||||
|
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Return True if any user content block is a populated image block."""
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
continue
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "image":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||||
|
|
||||||
|
Two sub-passes:
|
||||||
|
|
||||||
|
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||||
|
canonical task). No interjection row — the canonical task is already the
|
||||||
|
user utterance from ``meta/tasks.parquet``.
|
||||||
|
|
||||||
|
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||||
|
{role:user, style:interjection, content:<text>}
|
||||||
|
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||||
|
Both rows go in ``language_events`` at the same timestamp.
|
||||||
|
|
||||||
|
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||||
|
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import random
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import InterjectionsConfig
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
from ..writer import speech_atom
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InterjectionsAndSpeechModule:
|
||||||
|
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||||
|
|
||||||
|
vlm: VlmClient
|
||||||
|
config: InterjectionsConfig
|
||||||
|
seed: int = 1729
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self.config.enabled
|
||||||
|
|
||||||
|
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
if record.frame_timestamps:
|
||||||
|
t0 = float(record.frame_timestamps[0])
|
||||||
|
initial = self._initial_speech(record)
|
||||||
|
if initial:
|
||||||
|
rows.append(speech_atom(t0, initial))
|
||||||
|
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||||
|
# interjection prompt can ground itself in the actual current
|
||||||
|
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||||
|
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||||
|
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||||
|
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||||
|
staging.write("interjections", rows)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||||
|
current: str | None = None
|
||||||
|
for span in spans:
|
||||||
|
if float(span["start"]) <= t:
|
||||||
|
current = span.get("text")
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return current
|
||||||
|
|
||||||
|
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||||
|
prompt = load_prompt("interjections_initial_speech").format(
|
||||||
|
episode_task=record.episode_task,
|
||||||
|
)
|
||||||
|
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||||
|
text = result["text"].strip()
|
||||||
|
if text:
|
||||||
|
return text
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _mid_episode_interjections(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
subtask_spans: Sequence[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Generate interjections aligned with the actual demo trajectory.
|
||||||
|
|
||||||
|
Teleop data is frozen — the robot already executed every step in
|
||||||
|
the video. A *counterfactual* interjection like "actually skip
|
||||||
|
the wipe" contradicts what then happens in the video, which is
|
||||||
|
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||||
|
|
||||||
|
Instead, anchor every interjection at a subtask boundary and
|
||||||
|
write it as a natural user request for the *upcoming* subtask.
|
||||||
|
The robot's visible next behavior IS the interjection's effect,
|
||||||
|
so the training signal stays consistent: interjection text →
|
||||||
|
plan refresh → action stream all line up.
|
||||||
|
"""
|
||||||
|
if self.config.max_interjections_per_episode <= 0:
|
||||||
|
return []
|
||||||
|
if len(subtask_spans) < 2:
|
||||||
|
# Need at least one transition (subtask 0 → subtask 1).
|
||||||
|
return []
|
||||||
|
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||||
|
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||||
|
|
||||||
|
# Boundaries: the start time of every subtask except the first
|
||||||
|
# (which is just t0 and is covered by the initial-task speech atom).
|
||||||
|
boundaries: list[tuple[float, str, str]] = []
|
||||||
|
for i in range(1, len(subtask_spans)):
|
||||||
|
ts = float(subtask_spans[i]["start"])
|
||||||
|
if ts < self.config.interjection_min_t:
|
||||||
|
continue
|
||||||
|
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||||
|
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||||
|
if not next_text:
|
||||||
|
continue
|
||||||
|
boundaries.append((ts, prev_text, next_text))
|
||||||
|
if not boundaries:
|
||||||
|
return []
|
||||||
|
|
||||||
|
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||||
|
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||||
|
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for t, prev_subtask, next_subtask in chosen:
|
||||||
|
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||||
|
# Window straddles the boundary so the VLM sees the end of the
|
||||||
|
# previous subtask and the start of the next one — same
|
||||||
|
# conditioning the policy will see at training time.
|
||||||
|
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||||
|
prompt = load_prompt("interjections_interjection").format(
|
||||||
|
episode_task=record.episode_task,
|
||||||
|
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||||
|
next_subtask=next_subtask,
|
||||||
|
timestamp=t_snap,
|
||||||
|
window_seconds=self.config.interjection_window_seconds,
|
||||||
|
)
|
||||||
|
images = self.frame_provider.frames_at(record, window_ts)
|
||||||
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
continue
|
||||||
|
interjection_text = result.get("interjection")
|
||||||
|
speech_text = result.get("speech")
|
||||||
|
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||||
|
continue
|
||||||
|
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||||
|
continue
|
||||||
|
out.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": interjection_text.strip(),
|
||||||
|
"style": "interjection",
|
||||||
|
"timestamp": t_snap,
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||||
|
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||||
|
|
||||||
|
The window straddles the subtask boundary the interjection sits
|
||||||
|
on: roughly half the frames cover the end of the previous
|
||||||
|
subtask, half cover the start of the next one. The VLM therefore
|
||||||
|
sees BOTH what just finished AND what's about to start, which is
|
||||||
|
the conditioning we need to write a natural "now please do X"
|
||||||
|
request that matches the visible upcoming behavior.
|
||||||
|
"""
|
||||||
|
if not frame_timestamps:
|
||||||
|
return [t_anchor]
|
||||||
|
n = max(1, int(self.config.interjection_window_frames))
|
||||||
|
if n == 1:
|
||||||
|
return [t_anchor]
|
||||||
|
window = float(self.config.interjection_window_seconds)
|
||||||
|
step = window / max(1, n - 1)
|
||||||
|
# Center the window on the anchor so half lands before, half after.
|
||||||
|
start_offset = -window / 2.0
|
||||||
|
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||||
|
first_ts = float(frame_timestamps[0])
|
||||||
|
last_ts = float(frame_timestamps[-1])
|
||||||
|
snapped: list[float] = []
|
||||||
|
seen: set[float] = set()
|
||||||
|
for tgt in targets:
|
||||||
|
clamped = min(last_ts, max(first_ts, tgt))
|
||||||
|
t = snap_to_frame(clamped, frame_timestamps)
|
||||||
|
if t not in seen:
|
||||||
|
seen.add(t)
|
||||||
|
snapped.append(t)
|
||||||
|
return snapped or [t_anchor]
|
||||||
@@ -0,0 +1,780 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import PlanConfig
|
||||||
|
from ..frames import (
|
||||||
|
FrameProvider,
|
||||||
|
null_provider,
|
||||||
|
to_contact_sheet_blocks,
|
||||||
|
)
|
||||||
|
from ..prompts import load as load_prompt
|
||||||
|
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||||
|
from ..staging import EpisodeStaging
|
||||||
|
from ..vlm_client import VlmClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Prepended to every describe / segment prompt so the VLM knows the images are
|
||||||
|
# timestamped contact-sheet grids, not a single video, and reads the burned-in
|
||||||
|
# per-tile timestamp when choosing boundaries.
|
||||||
|
def _contact_sheet_preamble(columns: int) -> str:
|
||||||
|
return (
|
||||||
|
"CONTACT SHEETS — how to read the images below:\n"
|
||||||
|
f"- Each image is a grid of sampled video frames, {columns} per row, "
|
||||||
|
"with time running left-to-right then top-to-bottom (row-major).\n"
|
||||||
|
"- Each frame has its timestamp burned into the top-left corner, e.g. "
|
||||||
|
'"012.50s". Use that printed timestamp (not the tile position) when you '
|
||||||
|
"choose start/end times; boundaries should land on or near a printed "
|
||||||
|
"timestamp.\n"
|
||||||
|
"- Frames continue across grids: an action may span the end of one sheet "
|
||||||
|
"and the start of the next, so do not place a boundary just because a new "
|
||||||
|
"image begins.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Appended to every describe (and segment) prompt. A visual, causal definition
|
||||||
|
# of where one event ends and the next begins — adapted from macrodata/refiner —
|
||||||
|
# to sharpen cut points while the existing prompt keeps owning the imperative
|
||||||
|
# phrasing.
|
||||||
|
_CAUSAL_BOUNDARY_RULES = (
|
||||||
|
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
|
||||||
|
"- Start a new event whenever the world state changes: an object becomes "
|
||||||
|
"held (the gripper closes on it), an object is released (the gripper opens "
|
||||||
|
"and it stays put), an object reaches a new location, a lid/door/drawer "
|
||||||
|
"changes open/closed state, a tool starts or stops affecting a surface, or "
|
||||||
|
"contents visibly move (e.g. poured).\n"
|
||||||
|
"- If a single action changes the same state gradually and continuously, "
|
||||||
|
"keep it as ONE event — do not split it.\n"
|
||||||
|
"- If the same action repeats on different objects or target locations, "
|
||||||
|
"treat each repetition as a separate event.\n"
|
||||||
|
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
|
||||||
|
"tiny hand adjustments."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlanSubtasksMemoryModule:
|
||||||
|
"""Generate subtask spans, plan, and memory rows.
|
||||||
|
|
||||||
|
All output is persistent (lives in ``language_persistent``):
|
||||||
|
|
||||||
|
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||||
|
(snapped to an exact frame).
|
||||||
|
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||||
|
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||||
|
the ``interjections`` module completes).
|
||||||
|
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||||
|
timestamp from the second subtask onward).
|
||||||
|
"""
|
||||||
|
|
||||||
|
vlm: VlmClient
|
||||||
|
config: PlanConfig
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self.config.enabled
|
||||||
|
|
||||||
|
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||||
|
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||||
|
effective_task = self._resolve_effective_task(record)
|
||||||
|
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||||
|
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||||
|
# free-form n_task_rephrasings; the effective task is always emitted
|
||||||
|
# first so the rotation covers the source-of-truth phrasing.
|
||||||
|
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||||
|
variants: list[str] | None = None
|
||||||
|
if self.config.task_aug_axes.enabled and effective_task:
|
||||||
|
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||||
|
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||||
|
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||||
|
if variants is not None:
|
||||||
|
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||||
|
|
||||||
|
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||||
|
|
||||||
|
# subtask rows
|
||||||
|
for span in subtask_spans:
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": span["text"],
|
||||||
|
"style": "subtask",
|
||||||
|
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||||
|
# numbered list of still-todo subtasks, so re-emitting at each
|
||||||
|
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||||
|
# exactly what's left to do.
|
||||||
|
if self.config.emit_plan:
|
||||||
|
for span in subtask_spans:
|
||||||
|
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||||
|
plan_text = self._generate_plan(
|
||||||
|
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||||
|
)
|
||||||
|
if plan_text is not None:
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": plan_text,
|
||||||
|
"style": "plan",
|
||||||
|
"timestamp": float(boundary_t),
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# memory rows at every subtask boundary except the very first start;
|
||||||
|
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
|
||||||
|
prior_memory = ""
|
||||||
|
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
|
||||||
|
for i, span in memory_boundaries:
|
||||||
|
completed = subtask_spans[i - 1]["text"]
|
||||||
|
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||||
|
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||||
|
if mem_text:
|
||||||
|
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": mem_text,
|
||||||
|
"style": "memory",
|
||||||
|
"timestamp": ts,
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
prior_memory = mem_text
|
||||||
|
staging.write("plan", rows)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Task derivation + rephrasings
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"debug",
|
||||||
|
"test",
|
||||||
|
"tbd",
|
||||||
|
"todo",
|
||||||
|
"n/a",
|
||||||
|
"na",
|
||||||
|
"untitled",
|
||||||
|
"unnamed",
|
||||||
|
"default",
|
||||||
|
"placeholder",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||||
|
"""Decide which task string drives the ``plan`` module for this episode.
|
||||||
|
|
||||||
|
Returns the user-supplied ``record.episode_task`` unless
|
||||||
|
``derive_task_from_video`` says otherwise (see config docstring).
|
||||||
|
Falls back gracefully to the canonical task if video derivation
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
canonical = (record.episode_task or "").strip()
|
||||||
|
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||||
|
if mode == "always":
|
||||||
|
derived = self._derive_task_from_video(record)
|
||||||
|
return derived or canonical
|
||||||
|
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||||
|
derived = self._derive_task_from_video(record)
|
||||||
|
if derived:
|
||||||
|
return derived
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
def _task_seems_bad(self, task: str) -> bool:
|
||||||
|
if not task:
|
||||||
|
return True
|
||||||
|
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||||
|
return True
|
||||||
|
return task.lower() in self._PLACEHOLDER_TASKS
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||||
|
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||||
|
seen: set[str] = set()
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for phrasing in phrasings:
|
||||||
|
key = phrasing.strip()
|
||||||
|
if not key or key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
rows.append(
|
||||||
|
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||||
|
)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||||
|
# build messages → single VLM call → pull a named field.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||||
|
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||||
|
|
||||||
|
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||||
|
dance every prompt-call site needs.
|
||||||
|
"""
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result.get(field)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||||
|
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||||
|
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||||
|
|
||||||
|
def _video_message(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
prompt: str,
|
||||||
|
window: tuple[float, float] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
|
||||||
|
|
||||||
|
The prompt is always prefixed with a short explanation of how to read
|
||||||
|
the timestamped grids, so the model treats them as one ordered
|
||||||
|
sequence of frames rather than unrelated images.
|
||||||
|
"""
|
||||||
|
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
|
||||||
|
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||||
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||||
|
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||||
|
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||||
|
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||||
|
|
||||||
|
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||||
|
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||||
|
if n <= 0 or not base_task:
|
||||||
|
return []
|
||||||
|
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||||
|
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||||
|
if not isinstance(raw, list):
|
||||||
|
return []
|
||||||
|
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||||
|
return [s for s in out if s][:n]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||||
|
"""One VLM call → variants along the 5-axis taxonomy.
|
||||||
|
|
||||||
|
Variants from all axes are flattened into a single list (the
|
||||||
|
downstream pipeline doesn't need to know about the per-axis
|
||||||
|
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||||
|
is preserved for reproducibility: synonym_paraphrase first,
|
||||||
|
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||||
|
then combined_omissions.
|
||||||
|
"""
|
||||||
|
if not base_task:
|
||||||
|
return []
|
||||||
|
prompt = load_prompt("plan_task_aug_axes").format(
|
||||||
|
base_task=base_task,
|
||||||
|
n_synonym=axes_cfg.synonym_paraphrase,
|
||||||
|
n_omit_arm=axes_cfg.omit_arm,
|
||||||
|
n_omit_orientation=axes_cfg.omit_orientation,
|
||||||
|
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||||
|
n_combined=axes_cfg.combined_omissions,
|
||||||
|
)
|
||||||
|
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
return []
|
||||||
|
ordered_axes = (
|
||||||
|
"synonym_paraphrase",
|
||||||
|
"omit_arm",
|
||||||
|
"omit_orientation",
|
||||||
|
"omit_grasp_method",
|
||||||
|
"combined_omissions",
|
||||||
|
)
|
||||||
|
flat: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for axis in ordered_axes:
|
||||||
|
entries = result.get(axis)
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
continue
|
||||||
|
for item in entries:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
key = item.strip().strip('"').strip("'")
|
||||||
|
if not key or key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
flat.append(key)
|
||||||
|
return flat
|
||||||
|
|
||||||
|
def _episode_video_block(
|
||||||
|
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Timestamped contact sheets for the describe / segmentation prompts.
|
||||||
|
|
||||||
|
Always renders the (optionally windowed) episode as contact sheets:
|
||||||
|
frames sampled at ``frames_per_second`` and packed into timestamped
|
||||||
|
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
|
||||||
|
episodes that exceed it are windowed upstream in
|
||||||
|
:meth:`_generate_subtasks` so each call stays within budget while the
|
||||||
|
full episode keeps its sampling density.
|
||||||
|
|
||||||
|
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
|
||||||
|
(``ts - w0``) to match the window-relative time frame the
|
||||||
|
segmentation prompt works in (spans are offset back to absolute time
|
||||||
|
afterwards).
|
||||||
|
"""
|
||||||
|
if not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
if window is not None:
|
||||||
|
w0, w1 = float(window[0]), float(window[1])
|
||||||
|
dur = max(0.0, w1 - w0)
|
||||||
|
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||||
|
n = min(n, self.config.max_frames_per_prompt)
|
||||||
|
if n <= 1 or dur <= 0.0:
|
||||||
|
timestamps = [0.5 * (w0 + w1)]
|
||||||
|
else:
|
||||||
|
step = dur / (n - 1)
|
||||||
|
timestamps = [w0 + i * step for i in range(n)]
|
||||||
|
frames = self.frame_provider.frames_at(record, timestamps)
|
||||||
|
rel = [ts - w0 for ts in timestamps[: len(frames)]]
|
||||||
|
return self._contact_sheet_blocks(frames, rel)
|
||||||
|
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||||
|
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
|
||||||
|
n = min(n, self.config.max_frames_per_prompt)
|
||||||
|
timestamps = self._uniform_episode_timestamps(record, n)
|
||||||
|
frames = self.frame_provider.frames_at(record, timestamps)
|
||||||
|
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
|
||||||
|
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
|
||||||
|
ts = record.frame_timestamps
|
||||||
|
if n >= len(ts):
|
||||||
|
return [float(t) for t in ts]
|
||||||
|
t0, t_last = float(ts[0]), float(ts[-1])
|
||||||
|
if t_last <= t0 or n <= 1:
|
||||||
|
return [t0] * max(1, n)
|
||||||
|
step = (t_last - t0) / (n - 1)
|
||||||
|
return [t0 + i * step for i in range(n)]
|
||||||
|
|
||||||
|
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
|
||||||
|
"""Build timestamped contact-sheet image blocks from decoded frames."""
|
||||||
|
return to_contact_sheet_blocks(
|
||||||
|
frames,
|
||||||
|
timestamps,
|
||||||
|
columns=self.config.contact_sheet_columns,
|
||||||
|
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
|
||||||
|
frame_width=self.config.contact_sheet_frame_width,
|
||||||
|
quality=self.config.contact_sheet_quality,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_plan_updates(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
staging: EpisodeStaging,
|
||||||
|
interjection_times: Sequence[float],
|
||||||
|
interjection_texts: Sequence[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||||
|
|
||||||
|
Plans refresh ONLY on user interjections (event-driven). The
|
||||||
|
interjection text is forwarded into the prompt so the refreshed plan
|
||||||
|
reflects the user's correction.
|
||||||
|
"""
|
||||||
|
if not self.config.emit_plan:
|
||||||
|
return
|
||||||
|
existing = staging.read("plan")
|
||||||
|
# Pass the last frame timestamp so the final span is closed (else its
|
||||||
|
# end == start, zero duration, and a refresh inside it is missed).
|
||||||
|
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||||
|
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||||
|
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||||
|
new_rows = list(existing)
|
||||||
|
|
||||||
|
texts: list[str | None] = (
|
||||||
|
[None] * len(interjection_times)
|
||||||
|
if interjection_texts is None
|
||||||
|
else [str(t) if t else None for t in interjection_texts]
|
||||||
|
)
|
||||||
|
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||||
|
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||||
|
if t in already_planned:
|
||||||
|
continue
|
||||||
|
already_planned.add(t)
|
||||||
|
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||||
|
if plan_text is not None:
|
||||||
|
new_rows.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": plan_text,
|
||||||
|
"style": "plan",
|
||||||
|
"timestamp": t,
|
||||||
|
"tool_calls": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
staging.write("plan", new_rows)
|
||||||
|
|
||||||
|
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||||
|
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||||
|
|
||||||
|
Single call (default): watch video → emit subtask JSON.
|
||||||
|
|
||||||
|
Multi-call (opt-in, higher quality, more VLM calls):
|
||||||
|
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||||
|
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||||
|
its description is injected into the segmentation prompt so
|
||||||
|
the model segments its own grounded observations instead of
|
||||||
|
pattern-matching the task text.
|
||||||
|
2. segmentation — emit subtask JSON (as before).
|
||||||
|
"""
|
||||||
|
if record.row_count == 0 or not record.frame_timestamps:
|
||||||
|
return []
|
||||||
|
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||||
|
effective_task = task if task is not None else record.episode_task
|
||||||
|
|
||||||
|
# ---- Auto-windowing (keeps the full sampling density) --------
|
||||||
|
# Contact sheets are cheap, but a whole long episode sampled at
|
||||||
|
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
|
||||||
|
# When it does, split into consecutive windows of exactly that many
|
||||||
|
# frames (one describe→segment call each, still at the full sampling
|
||||||
|
# density), then merge + stitch — so an episode of any length is
|
||||||
|
# covered at full density rather than subsampled into one sparse call.
|
||||||
|
fps = max(1e-6, float(self.config.frames_per_second))
|
||||||
|
n_whole = int(round(episode_duration * fps)) + 1
|
||||||
|
if n_whole > self.config.max_frames_per_prompt:
|
||||||
|
window_s = self.config.max_frames_per_prompt / fps
|
||||||
|
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||||
|
|
||||||
|
# ---- Pass 1 (optional): grounding description ----------------
|
||||||
|
observation_block = ""
|
||||||
|
if getattr(self.config, "subtask_describe_first", False):
|
||||||
|
description = self._describe_episode(record, effective_task)
|
||||||
|
if description:
|
||||||
|
observation_block = (
|
||||||
|
"You watched this video and described, chronologically, "
|
||||||
|
"ONLY what the robot actually does:\n"
|
||||||
|
f'"""{description}"""\n\n'
|
||||||
|
"Segment THAT grounded description (cross-checked against "
|
||||||
|
"the video) into atomic subtasks. Do not introduce any "
|
||||||
|
"action that is not in your description above.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Pass 2: segmentation ------------------------------------
|
||||||
|
prompt = self._with_causal_rules(
|
||||||
|
load_prompt("plan_subtasks").format(
|
||||||
|
episode_task=effective_task,
|
||||||
|
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||||
|
max_steps=self.config.plan_max_steps,
|
||||||
|
episode_duration=f"{episode_duration:.3f}",
|
||||||
|
observation_block=observation_block,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||||
|
cleaned = self._clean_spans(spans, record)
|
||||||
|
if not cleaned:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# ---- Full-episode coverage stitch ----------------------------
|
||||||
|
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||||
|
# with no active subtask. Always stitch into a contiguous
|
||||||
|
# [t0, t_last] cover.
|
||||||
|
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _generate_subtasks_windowed(
|
||||||
|
self, record: EpisodeRecord, task: str, window_s: float
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Subtask generation in fixed-length windows at constant fps.
|
||||||
|
|
||||||
|
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||||
|
seconds, runs the describe -> segment chain on each window's own
|
||||||
|
frames (sampled at ``frames_per_second``), offsets
|
||||||
|
each window's spans back to absolute episode time, then merges +
|
||||||
|
stitches into a contiguous whole-episode cover.
|
||||||
|
"""
|
||||||
|
t0 = float(record.frame_timestamps[0])
|
||||||
|
t_last = float(record.frame_timestamps[-1])
|
||||||
|
all_spans: list[dict[str, Any]] = []
|
||||||
|
w0 = t0
|
||||||
|
n_windows = 0
|
||||||
|
while w0 < t_last - 1e-6:
|
||||||
|
w1 = min(w0 + window_s, t_last)
|
||||||
|
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||||
|
n_windows += 1
|
||||||
|
w0 = w1
|
||||||
|
logger.info(
|
||||||
|
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||||
|
record.episode_index,
|
||||||
|
n_windows,
|
||||||
|
window_s,
|
||||||
|
len(all_spans),
|
||||||
|
)
|
||||||
|
# Merge across windows: clamp to the absolute episode, sort, and
|
||||||
|
# frame-snap to distinct starts (handles any boundary collisions).
|
||||||
|
cleaned = self._clean_spans(all_spans, record)
|
||||||
|
if not cleaned:
|
||||||
|
return []
|
||||||
|
return self._stitch_full_coverage(cleaned, record)
|
||||||
|
|
||||||
|
def _subtasks_for_window(
|
||||||
|
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||||
|
|
||||||
|
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||||
|
the window as a clip starting at 0); spans are offset back to
|
||||||
|
absolute ``[w0, w1]`` before returning.
|
||||||
|
"""
|
||||||
|
window = (w0, w1)
|
||||||
|
win_len = max(0.0, w1 - w0)
|
||||||
|
|
||||||
|
observation_block = ""
|
||||||
|
if getattr(self.config, "subtask_describe_first", False):
|
||||||
|
description = self._describe_episode(record, task, window=window)
|
||||||
|
if description:
|
||||||
|
observation_block = (
|
||||||
|
"You watched this video clip and described, chronologically, "
|
||||||
|
"ONLY what the robot actually does:\n"
|
||||||
|
f'"""{description}"""\n\n'
|
||||||
|
"Segment THAT grounded description (cross-checked against "
|
||||||
|
"the clip) into atomic subtasks. Do not introduce any "
|
||||||
|
"action that is not in your description above.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = self._with_causal_rules(
|
||||||
|
load_prompt("plan_subtasks").format(
|
||||||
|
episode_task=task,
|
||||||
|
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||||
|
max_steps=self.config.plan_max_steps,
|
||||||
|
episode_duration=f"{win_len:.3f}",
|
||||||
|
observation_block=observation_block,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||||
|
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||||
|
# merged absolute set).
|
||||||
|
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||||
|
if not cleaned:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Offset window-relative spans back to absolute episode time.
|
||||||
|
for s in cleaned:
|
||||||
|
s["start"] = w0 + float(s["start"])
|
||||||
|
s["end"] = w0 + float(s["end"])
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _stitch_full_coverage(
|
||||||
|
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Make subtask spans tile the full episode with no gaps.
|
||||||
|
|
||||||
|
* The first subtask starts at the episode's first frame ``t0``
|
||||||
|
(any idle / approach before the first labelled action is folded
|
||||||
|
into it), so every early frame has an active subtask.
|
||||||
|
* Each subtask's ``end`` is snapped to the next subtask's
|
||||||
|
``start`` (gaps between spans are closed), and the final
|
||||||
|
subtask's ``end`` extends to the last frame ``t_last``.
|
||||||
|
|
||||||
|
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||||
|
values the VLM produced — only the FIRST start is pulled
|
||||||
|
back to ``t0``, which can't collide with a later span because it
|
||||||
|
was already the earliest. Purely deterministic; runs after the
|
||||||
|
VLM passes.
|
||||||
|
"""
|
||||||
|
if not spans or not record.frame_timestamps:
|
||||||
|
return spans
|
||||||
|
t0 = float(record.frame_timestamps[0])
|
||||||
|
t_last = float(record.frame_timestamps[-1])
|
||||||
|
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||||
|
spans[0]["start"] = t0
|
||||||
|
for i in range(len(spans) - 1):
|
||||||
|
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||||
|
spans[-1]["end"] = t_last
|
||||||
|
for s in spans:
|
||||||
|
if float(s["end"]) < float(s["start"]):
|
||||||
|
s["end"] = float(s["start"])
|
||||||
|
return spans
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _with_causal_rules(prompt: str) -> str:
|
||||||
|
"""Append the causal event-boundary rules to a describe/segment prompt."""
|
||||||
|
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
|
||||||
|
|
||||||
|
def _clean_spans(
|
||||||
|
self,
|
||||||
|
spans: Any,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
bounds: tuple[float, float] | None = None,
|
||||||
|
dedupe: bool = True,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||||
|
|
||||||
|
``bounds`` overrides the clamp range — pass the window's
|
||||||
|
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||||
|
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||||
|
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||||
|
window-relative spans (frame snapping is done once on the merged,
|
||||||
|
absolute-time set).
|
||||||
|
"""
|
||||||
|
if not spans:
|
||||||
|
return []
|
||||||
|
if bounds is not None:
|
||||||
|
lo, hi = float(bounds[0]), float(bounds[1])
|
||||||
|
else:
|
||||||
|
lo = record.frame_timestamps[0]
|
||||||
|
hi = record.frame_timestamps[-1]
|
||||||
|
cleaned: list[dict[str, Any]] = []
|
||||||
|
for span in spans:
|
||||||
|
try:
|
||||||
|
start = float(span["start"])
|
||||||
|
end = float(span["end"])
|
||||||
|
text = str(span["text"]).strip()
|
||||||
|
except (KeyError, ValueError, TypeError):
|
||||||
|
continue
|
||||||
|
start = max(lo, min(start, hi))
|
||||||
|
end = max(lo, min(end, hi))
|
||||||
|
if end < start:
|
||||||
|
start, end = end, start
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
cleaned.append({"text": text, "start": start, "end": end})
|
||||||
|
cleaned.sort(key=lambda s: s["start"])
|
||||||
|
if dedupe:
|
||||||
|
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _describe_episode(
|
||||||
|
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||||
|
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
|
||||||
|
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||||
|
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dedupe_starts_to_distinct_frames(
|
||||||
|
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Bump same-frame subtask starts onto distinct frames.
|
||||||
|
|
||||||
|
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||||
|
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||||
|
two ``style=subtask`` rows at the identical persistent
|
||||||
|
timestamp. The training-time renderer's ``active_at(t,
|
||||||
|
style=subtask)`` resolver can't disambiguate that and raises
|
||||||
|
``Ambiguous resolver for style='subtask'``.
|
||||||
|
|
||||||
|
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||||
|
if the snapped frame is already taken push the span onto the
|
||||||
|
next unused frame so both subtasks survive on distinct
|
||||||
|
timestamps. If the episode ends before a free frame is found,
|
||||||
|
the trailing span is dropped with a warning — better than
|
||||||
|
poisoning the render.
|
||||||
|
"""
|
||||||
|
if not spans:
|
||||||
|
return spans
|
||||||
|
frames = record.frame_timestamps
|
||||||
|
if not frames:
|
||||||
|
return spans
|
||||||
|
used: set[float] = set()
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for span in spans:
|
||||||
|
ts = snap_to_frame(span["start"], frames)
|
||||||
|
if ts in used:
|
||||||
|
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||||
|
if next_ts is None:
|
||||||
|
logger.warning(
|
||||||
|
"episode %d: subtask %r snapped to occupied frame "
|
||||||
|
"%.3f and no free later frame exists — dropping",
|
||||||
|
record.episode_index,
|
||||||
|
span.get("text"),
|
||||||
|
ts,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
ts = next_ts
|
||||||
|
used.add(ts)
|
||||||
|
new_span = {**span, "start": ts}
|
||||||
|
if float(new_span.get("end", ts)) < ts:
|
||||||
|
new_span["end"] = ts
|
||||||
|
out.append(new_span)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _generate_plan(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||||
|
subtask_spans: Sequence[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
refresh_t: float | None = None,
|
||||||
|
interjection: str | None = None, # noqa: ARG002
|
||||||
|
task: str | None = None, # noqa: ARG002
|
||||||
|
) -> str | None:
|
||||||
|
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||||
|
|
||||||
|
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||||
|
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||||
|
cost a round-trip per episode/refresh and could diverge).
|
||||||
|
|
||||||
|
1. <subtask 1>
|
||||||
|
2. <subtask 2>
|
||||||
|
|
||||||
|
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||||
|
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||||
|
starting at or after ``refresh_t`` are included — so it always
|
||||||
|
describes what's left.
|
||||||
|
"""
|
||||||
|
if not subtask_spans:
|
||||||
|
return None
|
||||||
|
remaining = [
|
||||||
|
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||||
|
]
|
||||||
|
if not remaining:
|
||||||
|
# Past the last subtask boundary on a late refresh — nothing
|
||||||
|
# left to plan; emit None so the caller skips the row.
|
||||||
|
return None
|
||||||
|
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||||
|
|
||||||
|
def _generate_memory(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
prior_memory: str,
|
||||||
|
completed: str,
|
||||||
|
remaining: Sequence[str],
|
||||||
|
*,
|
||||||
|
task: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
prompt = load_prompt("plan_memory").format(
|
||||||
|
episode_task=(task if task is not None else record.episode_task),
|
||||||
|
prior_memory=prior_memory or "(none)",
|
||||||
|
completed_subtask=completed,
|
||||||
|
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||||
|
)
|
||||||
|
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||||
|
return memory.strip() if isinstance(memory, str) else ""
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Prompt templates loaded as plain text.
|
||||||
|
|
||||||
|
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||||
|
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||||
|
plain editors and roundtrip cleanly through ``ruff format``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
def load(name: str) -> str:
|
||||||
|
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||||
|
path = _DIR / f"{name}.txt"
|
||||||
|
return path.read_text(encoding="utf-8")
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
The user just asked the robot: "{episode_task}".
|
||||||
|
|
||||||
|
Generate a short verbal acknowledgement the robot would speak back before
|
||||||
|
beginning the task. Style: compact, confident, friendly.
|
||||||
|
|
||||||
|
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||||
|
"OK, starting with the sponge.", "Got it.".
|
||||||
|
|
||||||
|
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{ "text": "<the spoken acknowledgement>" }}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
You are generating training data for a Hi Robot-style hierarchical
|
||||||
|
robot policy. The robot in this demonstration has ALREADY executed
|
||||||
|
every step shown in the video — we cannot retroactively change the
|
||||||
|
action stream. To keep training data consistent with the video, the
|
||||||
|
"interjection" must align with what the robot is *about to do next* in
|
||||||
|
the demonstration, framed as a natural mid-task user request.
|
||||||
|
|
||||||
|
The episode's overall task: "{episode_task}".
|
||||||
|
|
||||||
|
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||||
|
subtask boundary in the demonstration:
|
||||||
|
|
||||||
|
- Subtask the robot just finished: "{prev_subtask}"
|
||||||
|
- Subtask the robot is about to start: "{next_subtask}"
|
||||||
|
- Time into episode: {timestamp:.2f}s
|
||||||
|
|
||||||
|
Write ONE compact interjection the user would naturally say at this
|
||||||
|
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||||
|
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||||
|
Also write the robot's compact verbal acknowledgement.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
|
||||||
|
- The interjection MUST be consistent with the next subtask. The user
|
||||||
|
cannot ask for something different from what the robot then does in
|
||||||
|
the video. If you're tempted to say "actually skip X" or "do Y
|
||||||
|
instead", DO NOT — those would contradict the demonstration.
|
||||||
|
- The interjection must reference an object, location, or action that
|
||||||
|
is plausible given the visible scene and the next subtask text.
|
||||||
|
- One short phrase or sentence each. Conversational, not robotic.
|
||||||
|
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||||
|
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||||
|
|
||||||
|
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||||
|
- "Now go ahead and {next_subtask}."
|
||||||
|
- "Great, can you {next_subtask} next?"
|
||||||
|
- "{next_subtask}, please."
|
||||||
|
- "Before you continue, please {next_subtask}."
|
||||||
|
- "Looking good — {next_subtask} now."
|
||||||
|
- "Okay, {next_subtask}."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||||
|
"speech": "<short robot acknowledgement>"
|
||||||
|
}}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
You are updating the robot's compressed semantic memory at the boundary of
|
||||||
|
a completed subtask.
|
||||||
|
|
||||||
|
Reference (verbatim from MEM, Torne 2026):
|
||||||
|
"Remove or compress information in the language memory whenever
|
||||||
|
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||||
|
task execution. Specific object attributes (colors, precise quantities of
|
||||||
|
each item) get discarded when their details won't affect subsequent
|
||||||
|
actions. Functional outcomes (where items went, how many) are preserved."
|
||||||
|
|
||||||
|
Episode task: "{episode_task}"
|
||||||
|
Previous memory: {prior_memory}
|
||||||
|
Just-completed subtask: "{completed_subtask}"
|
||||||
|
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||||
|
|
||||||
|
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||||
|
robot has accomplished so far — the running story it would tell itself.
|
||||||
|
|
||||||
|
Authoring rules:
|
||||||
|
- First person, past tense. Every sentence starts with "I": "I picked
|
||||||
|
up...", "I opened...", "I moved to...".
|
||||||
|
- One or two short sentences. Extend the previous memory with the
|
||||||
|
just-completed subtask; do not rewrite it from scratch.
|
||||||
|
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||||
|
drop HOW (grasp details, motions).
|
||||||
|
- Compress completed steps and drop object attributes (colors, exact
|
||||||
|
counts) once they no longer affect the remaining subtasks.
|
||||||
|
|
||||||
|
Example (MEM, Torne 2026):
|
||||||
|
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||||
|
moved to the drawer."
|
||||||
|
After: "I prepared the pot and got the ingredients. I opened the
|
||||||
|
drawer with the masher."
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
You are watching a teleoperated robot demonstration from a single
|
||||||
|
camera. The user asked the robot to: "{episode_task}"
|
||||||
|
|
||||||
|
This is an OBSERVATION pass. Watch the entire clip and describe, in
|
||||||
|
chronological order, ONLY what the robot physically does — the concrete
|
||||||
|
motions, approaches, contacts, grasps, releases, and relocations you can
|
||||||
|
actually SEE in the frames.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Describe only motion visible in the video. Do NOT use the task
|
||||||
|
instruction to guess steps that aren't shown. The instruction is the
|
||||||
|
goal; the video is ground truth.
|
||||||
|
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
|
||||||
|
the single field below. Just narrate what happens.
|
||||||
|
- Give an approximate timestamp (in seconds) for each distinct event,
|
||||||
|
e.g. "0.0-1.4s: the base drives forward toward the stove".
|
||||||
|
- Do NOT invent objects, grasps, destinations, or steps. If the robot
|
||||||
|
only does one thing (e.g. it just navigates and the clip ends), say
|
||||||
|
exactly that and nothing more.
|
||||||
|
- Be concrete and literal. "the gripper closes on the mug" — not "the
|
||||||
|
robot prepares to make coffee".
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"description": "<chronological, timestamped description of ONLY what is visible>"
|
||||||
|
}}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
You are labeling a teleoperated robot demonstration.
|
||||||
|
|
||||||
|
The user originally asked: "{episode_task}"
|
||||||
|
|
||||||
|
You are shown the entire demonstration as a single video. Watch the
|
||||||
|
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||||
|
the robot performs.
|
||||||
|
|
||||||
|
{observation_block}GROUNDING — read this first, it overrides everything below:
|
||||||
|
- Label ONLY what the robot actually does in the video. Every subtask
|
||||||
|
you emit must correspond to motion you can SEE in specific frames.
|
||||||
|
- Do NOT invent, anticipate, or pad. If the robot only does one thing
|
||||||
|
(e.g. it just navigates to a location and the clip ends), emit
|
||||||
|
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
|
||||||
|
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
|
||||||
|
subtasks than the ceiling is not just allowed, it is expected for
|
||||||
|
short / atomic demonstrations. One correct subtask is far better
|
||||||
|
than several invented ones.
|
||||||
|
- If the video does not clearly show the action implied by the task,
|
||||||
|
describe what you actually see — do NOT fabricate the task's steps
|
||||||
|
from the instruction text. The instruction tells you the goal; the
|
||||||
|
VIDEO is the ground truth for what happened.
|
||||||
|
|
||||||
|
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||||
|
|
||||||
|
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||||
|
execute end-to-end. A "skill" bundles its own approach motion with
|
||||||
|
its terminal action — do NOT split the approach off as its own
|
||||||
|
subtask. The whole-arm policy already learns to reach as part of
|
||||||
|
every manipulation primitive.
|
||||||
|
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||||
|
these verbs (extend only when none fits):
|
||||||
|
pick up <obj> — approach + grasp + lift in one subtask
|
||||||
|
put <obj> on/in <loc> — transport + release in one subtask
|
||||||
|
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||||
|
push <obj> — contact + linear shove
|
||||||
|
pull <obj> — contact + linear retract
|
||||||
|
turn <knob/dial/handle> — rotary actuation
|
||||||
|
press <button> — single-press contact
|
||||||
|
open <drawer/door/lid> — full open motion
|
||||||
|
close <drawer/door/lid> — full close motion
|
||||||
|
pour <src> into <dst> — tilt + flow
|
||||||
|
insert <obj> into <slot>— alignment + push-fit
|
||||||
|
go to <loc> — ONLY when no grasp / actuation follows
|
||||||
|
(e.g. a pure relocation between phases).
|
||||||
|
If the next subtask grasps something at
|
||||||
|
that location, drop "go to ..." and just
|
||||||
|
write "pick up ..." instead.
|
||||||
|
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||||
|
as standalone subtasks; fold them into the parent composite:
|
||||||
|
"move to X" → fold into "pick up X" (or whatever follows)
|
||||||
|
"reach for X" → fold into "pick up X"
|
||||||
|
"grasp X" → fold into "pick up X"
|
||||||
|
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||||
|
the transport phase of a place)
|
||||||
|
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||||
|
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||||
|
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||||
|
detail (which hand, which grasp point) ONLY when it is needed to
|
||||||
|
disambiguate. Every subtask must begin with one of the verbs
|
||||||
|
above (no leading nouns, no "then", no "first").
|
||||||
|
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||||
|
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||||
|
do not describe it.
|
||||||
|
- Use the exact object nouns from the task above. If the task says
|
||||||
|
"cube", every subtask says "cube" — never switch to "block". If it
|
||||||
|
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||||
|
consistent across the whole episode.
|
||||||
|
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||||
|
"turn red knob", "press start button", "go to sink".
|
||||||
|
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||||
|
must be folded into "pick up blue cube"); "the robot arm moves
|
||||||
|
towards the blue cube" (third person, too long); "carefully pick
|
||||||
|
up the cube" (adverb, article); "release the yellow block"
|
||||||
|
("block" when the task said "cube", and "release" must be folded
|
||||||
|
into a "put"/"place" subtask).
|
||||||
|
- Subtasks are non-overlapping and cover the full episode in order.
|
||||||
|
Choose the cut points yourself based on what you see in the video
|
||||||
|
(gripper open/close events, contact, regrasps, transitions).
|
||||||
|
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||||
|
candidate span would be shorter, merge it into its neighbour
|
||||||
|
rather than emitting it.
|
||||||
|
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||||
|
are preferred over many micro-steps.
|
||||||
|
- Every subtask's [start_time, end_time] must lie within
|
||||||
|
[0.0, {episode_duration}] seconds.
|
||||||
|
|
||||||
|
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
|
||||||
|
fires ONLY on the spatial situation it names; it must not change how you
|
||||||
|
label any other situation):
|
||||||
|
- STACK vs PUT: if an object is placed ON TOP OF another specific object
|
||||||
|
(not on a flat table / shelf / counter), use "stack ... on ...", not
|
||||||
|
"put". "stack blue book on green book", NOT "put blue book on table".
|
||||||
|
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
|
||||||
|
receptacle (push-fit), use "insert ... into ...", not "put".
|
||||||
|
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
|
||||||
|
on the object and the object moves WITH the hand, it is "pick up" /
|
||||||
|
"retrieve" (object leaves its location). If the gripper OPENS and the
|
||||||
|
object stays where the hand left it, it is "put" / "place" (object
|
||||||
|
arrives at a location). Decide by which way the object moves, not by
|
||||||
|
where the hand ends up.
|
||||||
|
- POUR vs PUT: only use "pour" when the source is tilted and contents
|
||||||
|
flow out; moving a full container without tilting is "put"/"place".
|
||||||
|
|
||||||
|
Output strictly valid JSON of shape:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"subtasks": [
|
||||||
|
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
You are generating structured augmentations of a robot task instruction
|
||||||
|
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||||
|
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||||
|
a specific element of the task while preserving its meaning.
|
||||||
|
|
||||||
|
Original task: "{base_task}"
|
||||||
|
|
||||||
|
Produce variants along five named axes. Each axis has a target count.
|
||||||
|
The whole batch should expose the policy to maximum linguistic diversity
|
||||||
|
WITHOUT changing what the robot is supposed to do.
|
||||||
|
|
||||||
|
Axes and target counts:
|
||||||
|
|
||||||
|
synonym_paraphrase ({n_synonym}):
|
||||||
|
Different wording / verbs / sentence structure. ALL information
|
||||||
|
from the original task is preserved — same object, same arm
|
||||||
|
specification if present, same orientation if present, same grasp
|
||||||
|
if present.
|
||||||
|
|
||||||
|
omit_arm ({n_omit_arm}):
|
||||||
|
Drop the left/right/both arm specification from the task. Skip
|
||||||
|
entirely (emit 0 entries) if the original task does NOT mention an
|
||||||
|
arm. Do not invent an arm specification just to omit it.
|
||||||
|
|
||||||
|
omit_orientation ({n_omit_orientation}):
|
||||||
|
Drop orientation cues (upright, sideways, facing the user,
|
||||||
|
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||||
|
present in the original task.
|
||||||
|
|
||||||
|
omit_grasp_method ({n_omit_grasp_method}):
|
||||||
|
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||||
|
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||||
|
|
||||||
|
combined_omissions ({n_combined}):
|
||||||
|
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||||
|
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||||
|
orientation, grasp_method) appear in the original task.
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Each variant MUST preserve the core action, the target object, AND
|
||||||
|
the goal / destination. Do not change which object is involved, where
|
||||||
|
it goes, or the high-level action. "Navigate to the stove" may become
|
||||||
|
"go to the stove" or "head over to the stove" — it must NEVER become
|
||||||
|
"wander around the kitchen", "explore the room", or anything that
|
||||||
|
drops or generalises the stove destination. If you cannot vary the
|
||||||
|
wording without changing the goal, emit fewer variants.
|
||||||
|
- Only the FIVE listed elements (wording, arm, orientation, grasp
|
||||||
|
method, or a combination) may be varied or omitted. The verb's
|
||||||
|
meaning, the object, and the destination are fixed.
|
||||||
|
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||||
|
- Each variant must be DISTINCT from every other variant in the entire
|
||||||
|
output, both within and across axes. Near-duplicates are not allowed.
|
||||||
|
- If an axis cannot reach its target count because the original task
|
||||||
|
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||||
|
axis with paraphrases that belong to a different axis.
|
||||||
|
- Variants should not all start with verbs — vary sentence structure
|
||||||
|
(some imperative, some polite request, some question).
|
||||||
|
|
||||||
|
Output strictly valid JSON of shape:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||||
|
"omit_arm": ["<v1>", "<v2>", ...],
|
||||||
|
"omit_orientation": ["<v1>", ...],
|
||||||
|
"omit_grasp_method": ["<v1>", ...],
|
||||||
|
"combined_omissions": ["<v1>", ...]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
You are generating training data for a Hi Robot-style policy. We need
|
||||||
|
{n} alternative phrasings of the same robot task so the policy sees
|
||||||
|
diverse user prompts during training instead of the same canonical
|
||||||
|
string repeated every frame.
|
||||||
|
|
||||||
|
Original task:
|
||||||
|
"{base_task}"
|
||||||
|
|
||||||
|
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||||
|
|
||||||
|
- formality (casual / polite / curt)
|
||||||
|
- verbosity (mostly short imperative; occasional polite request)
|
||||||
|
- word choice (synonyms, different verbs)
|
||||||
|
- sentence structure (imperative / question / suggestion)
|
||||||
|
|
||||||
|
Hard rules:
|
||||||
|
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||||
|
Do not change which object is involved, the destination, or the
|
||||||
|
action. Do not add extra steps. Do not invent new objects.
|
||||||
|
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||||
|
markdown, no quotes, no list numbers.
|
||||||
|
- Phrasings must be distinct — no near-duplicates.
|
||||||
|
- Output exactly {n} entries.
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"rephrasings": [
|
||||||
|
"<phrasing 1>",
|
||||||
|
"<phrasing 2>",
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
The video above shows a robot manipulation episode in full. Look at
|
||||||
|
the entire video and describe in ONE concise sentence what the robot
|
||||||
|
is doing.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- One sentence, in natural English, like a user instruction.
|
||||||
|
- Capture the goal of the demonstration, not low-level motions.
|
||||||
|
Example: "place the yellow cube into the red bin" — not "move the
|
||||||
|
end-effector down 5cm and close the gripper".
|
||||||
|
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||||
|
- Do not invent objects or actions that aren't visible.
|
||||||
|
- Do not output anything other than the JSON object below.
|
||||||
|
|
||||||
|
Output strictly valid JSON:
|
||||||
|
{{
|
||||||
|
"task": "<single concise sentence describing what the robot does in this video>"
|
||||||
|
}}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
You are generating a frame-grounded visual question/answer pair for
|
||||||
|
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||||
|
Policies — both train policies on grounded features such as bounding box
|
||||||
|
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||||
|
|
||||||
|
The frame shows a robot working on: "{episode_task}".
|
||||||
|
|
||||||
|
Question types and the EXACT answer JSON shape required for each:
|
||||||
|
|
||||||
|
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||||
|
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||||
|
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||||
|
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||||
|
|
||||||
|
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||||
|
"point": [x, y]}}
|
||||||
|
|
||||||
|
count => {{"label": "<obj>", "count": <int>,
|
||||||
|
"note": "<optional short note>"}}
|
||||||
|
|
||||||
|
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||||
|
"value": "<observed value>"}}
|
||||||
|
|
||||||
|
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||||
|
"above|below|near>", "object": "<obj>"}}
|
||||||
|
|
||||||
|
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||||
|
|
||||||
|
{{
|
||||||
|
"question": "<short, frame-grounded question>",
|
||||||
|
"answer": <object whose shape matches the schema above>
|
||||||
|
}}
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Datatrove-shaped reader.
|
||||||
|
|
||||||
|
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||||
|
episode containing:
|
||||||
|
|
||||||
|
- ``episode_index``: int
|
||||||
|
- ``frame_timestamps``: tuple[float, ...]
|
||||||
|
- ``frame_indices``: tuple[int, ...]
|
||||||
|
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||||
|
- ``data_path``: pathlib.Path of the source parquet shard
|
||||||
|
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||||
|
|
||||||
|
This shape lets each module operate per-episode without loading all parquet
|
||||||
|
rows into memory at once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.io_utils import load_tasks
|
||||||
|
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EpisodeRecord:
|
||||||
|
"""Per-episode record yielded by the reader."""
|
||||||
|
|
||||||
|
episode_index: int
|
||||||
|
episode_task: str
|
||||||
|
frame_timestamps: tuple[float, ...]
|
||||||
|
frame_indices: tuple[int, ...]
|
||||||
|
data_path: Path
|
||||||
|
row_offset: int # row offset within the parquet file where this episode starts
|
||||||
|
row_count: int # number of rows for this episode
|
||||||
|
|
||||||
|
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||||
|
# repeat queries from different modules don't re-read the whole shard.
|
||||||
|
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||||
|
|
||||||
|
def frames_df(self): # type: ignore[no-untyped-def]
|
||||||
|
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||||
|
if self._frames_df_cache is None:
|
||||||
|
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||||
|
|
||||||
|
table = pq.read_table(self.data_path)
|
||||||
|
df: pd.DataFrame = table.to_pandas()
|
||||||
|
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||||
|
drop=True
|
||||||
|
)
|
||||||
|
return self._frames_df_cache
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_subtask_spans(
|
||||||
|
rows: Sequence[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
episode_end_t: float | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||||
|
|
||||||
|
Each span's ``end`` is the next span's ``start``. The final span's
|
||||||
|
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||||
|
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||||
|
which is what downstream consumers (memory, interjection boundary
|
||||||
|
selection) expect.
|
||||||
|
|
||||||
|
Used by the ``plan`` module (plan-update pass) and the
|
||||||
|
``interjections`` module (interjection anchoring), which both need the
|
||||||
|
same span shape.
|
||||||
|
"""
|
||||||
|
sorted_rows = sorted(
|
||||||
|
(r for r in rows if r.get("style") == "subtask"),
|
||||||
|
key=lambda r: float(r["timestamp"]),
|
||||||
|
)
|
||||||
|
spans: list[dict[str, Any]] = []
|
||||||
|
for r in sorted_rows:
|
||||||
|
t = float(r["timestamp"])
|
||||||
|
if spans:
|
||||||
|
spans[-1]["end"] = t
|
||||||
|
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||||
|
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||||
|
spans[-1]["end"] = float(episode_end_t)
|
||||||
|
return spans
|
||||||
|
|
||||||
|
|
||||||
|
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||||
|
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||||
|
|
||||||
|
Modules use this when emitting event-style rows so the row's
|
||||||
|
timestamp matches a real parquet frame: event rows must land on an
|
||||||
|
exact frame, otherwise the per-frame event lookup the writer does
|
||||||
|
would never match them.
|
||||||
|
"""
|
||||||
|
if not frame_timestamps:
|
||||||
|
return float(t)
|
||||||
|
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||||
|
return float(nearest)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||||
|
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||||
|
|
||||||
|
Returns an empty dict when the file is absent — the task description is
|
||||||
|
derived later from the video if needed. Reuses the library-level
|
||||||
|
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||||
|
frame indexed by task string with a ``task_index`` column.
|
||||||
|
"""
|
||||||
|
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||||
|
return {}
|
||||||
|
tasks = load_tasks(root)
|
||||||
|
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||||
|
|
||||||
|
|
||||||
|
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||||
|
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||||
|
|
||||||
|
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||||
|
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||||
|
under ``data/`` and groups by ``episode_index``.
|
||||||
|
"""
|
||||||
|
tasks = _load_tasks_lookup(root)
|
||||||
|
data_dir = root / "data"
|
||||||
|
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||||
|
|
||||||
|
only_set = set(only_episodes) if only_episodes is not None else None
|
||||||
|
|
||||||
|
for path in parquet_files:
|
||||||
|
yield from _iter_one_path(path, tasks, only_set)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||||
|
table = pq.read_table(path)
|
||||||
|
names = table.column_names
|
||||||
|
if "episode_index" not in names:
|
||||||
|
return
|
||||||
|
episode_col = table.column("episode_index").to_pylist()
|
||||||
|
timestamp_col = (
|
||||||
|
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||||
|
)
|
||||||
|
frame_col = (
|
||||||
|
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||||
|
)
|
||||||
|
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||||
|
|
||||||
|
def _build(
|
||||||
|
ep: int,
|
||||||
|
start: int,
|
||||||
|
end: int,
|
||||||
|
task_idx: int | None,
|
||||||
|
ts_buf: list[float],
|
||||||
|
fi_buf: list[int],
|
||||||
|
) -> EpisodeRecord | None:
|
||||||
|
if only_set is not None and ep not in only_set:
|
||||||
|
return None
|
||||||
|
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||||
|
return EpisodeRecord(
|
||||||
|
episode_index=ep,
|
||||||
|
episode_task=task,
|
||||||
|
frame_timestamps=tuple(ts_buf),
|
||||||
|
frame_indices=tuple(fi_buf),
|
||||||
|
data_path=path,
|
||||||
|
row_offset=start,
|
||||||
|
row_count=end - start,
|
||||||
|
)
|
||||||
|
|
||||||
|
cur_ep: int | None = None
|
||||||
|
start_offset = 0
|
||||||
|
ts_buf: list[float] = []
|
||||||
|
fi_buf: list[int] = []
|
||||||
|
cur_task_idx: int | None = None
|
||||||
|
|
||||||
|
for i, ep in enumerate(episode_col):
|
||||||
|
if cur_ep is None:
|
||||||
|
cur_ep = ep
|
||||||
|
start_offset = i
|
||||||
|
ts_buf = [timestamp_col[i]]
|
||||||
|
fi_buf = [frame_col[i]]
|
||||||
|
cur_task_idx = task_col[i] if task_col is not None else None
|
||||||
|
continue
|
||||||
|
if ep != cur_ep:
|
||||||
|
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||||
|
if rec is not None:
|
||||||
|
yield rec
|
||||||
|
cur_ep = ep
|
||||||
|
start_offset = i
|
||||||
|
ts_buf = [timestamp_col[i]]
|
||||||
|
fi_buf = [frame_col[i]]
|
||||||
|
cur_task_idx = task_col[i] if task_col is not None else None
|
||||||
|
else:
|
||||||
|
ts_buf.append(timestamp_col[i])
|
||||||
|
fi_buf.append(frame_col[i])
|
||||||
|
|
||||||
|
if cur_ep is not None:
|
||||||
|
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||||
|
if rec is not None:
|
||||||
|
yield rec
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Per-episode staging.
|
||||||
|
|
||||||
|
Each module writes its raw output as a JSONL file under
|
||||||
|
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||||
|
staging tree and partitions rows into the two language columns.
|
||||||
|
|
||||||
|
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||||
|
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||||
|
appended to. The final dataset format is parquet; staging is just an
|
||||||
|
intermediate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
ModuleName = str
|
||||||
|
|
||||||
|
_MODULES: tuple[ModuleName, ...] = (
|
||||||
|
"plan",
|
||||||
|
"interjections",
|
||||||
|
"vqa",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EpisodeStaging:
|
||||||
|
"""Filesystem layout for a single episode's staged module outputs."""
|
||||||
|
|
||||||
|
root: Path
|
||||||
|
episode_index: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def episode_dir(self) -> Path:
|
||||||
|
return self.root / f"episode_{self.episode_index:06d}"
|
||||||
|
|
||||||
|
def path_for(self, module: ModuleName) -> Path:
|
||||||
|
if module not in _MODULES:
|
||||||
|
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||||
|
return self.episode_dir / f"{module}.jsonl"
|
||||||
|
|
||||||
|
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||||
|
path = self.path_for(module)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Atomic replace: a crash mid-write would otherwise leave a
|
||||||
|
# half-written JSONL file that ``read()`` would then fail to
|
||||||
|
# parse. Write to a sibling .tmp and rename so the target path
|
||||||
|
# only ever points at a complete file.
|
||||||
|
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
with tmp_path.open("w", encoding="utf-8") as f:
|
||||||
|
for row in rows:
|
||||||
|
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||||
|
f.write("\n")
|
||||||
|
tmp_path.replace(path)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||||
|
path = self.path_for(module)
|
||||||
|
if not path.exists():
|
||||||
|
return []
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
with path.open(encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
out.append(json.loads(line))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||||
|
return {m: self.read(m) for m in _MODULES}
|
||||||
|
|
||||||
|
def has(self, module: ModuleName) -> bool:
|
||||||
|
return self.path_for(module).exists()
|
||||||
@@ -0,0 +1,332 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Pre-write validation against staged outputs.
|
||||||
|
|
||||||
|
Runs after all three modules have written their per-episode artifacts but
|
||||||
|
*before* the writer rewrites parquet shards. The validator never touches
|
||||||
|
parquet; it only inspects the staging tree and the source frame timestamps
|
||||||
|
exposed by :class:`EpisodeRecord`.
|
||||||
|
|
||||||
|
Checks (per the plan's "Intermediate staging and validation" section):
|
||||||
|
|
||||||
|
- exact timestamp alignment against source frame timestamps
|
||||||
|
- no orphan speech / interjection pairs
|
||||||
|
- plan / memory emission consistency (events have a paired persistent row)
|
||||||
|
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||||
|
attribute / spatial)
|
||||||
|
- every row maps to its correct column under :func:`column_for_style`
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.datasets.language import (
|
||||||
|
LANGUAGE_EVENTS,
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
column_for_style,
|
||||||
|
is_view_dependent_style,
|
||||||
|
validate_camera_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationReport:
|
||||||
|
"""Outcome of one validation pass across all episodes."""
|
||||||
|
|
||||||
|
errors: list[str] = field(default_factory=list)
|
||||||
|
warnings: list[str] = field(default_factory=list)
|
||||||
|
episodes_checked: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ok(self) -> bool:
|
||||||
|
return not self.errors
|
||||||
|
|
||||||
|
def add_error(self, message: str) -> None:
|
||||||
|
self.errors.append(message)
|
||||||
|
|
||||||
|
def add_warning(self, message: str) -> None:
|
||||||
|
self.warnings.append(message)
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||||
|
|
||||||
|
|
||||||
|
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||||
|
"bbox": {"detections"},
|
||||||
|
"keypoint": {"label", "point_format", "point"},
|
||||||
|
"count": {"label", "count"},
|
||||||
|
"attribute": {"label", "attribute", "value"},
|
||||||
|
"spatial": {"subject", "relation", "object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def classify_vqa_answer(payload: Any) -> str | None:
|
||||||
|
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
keys = set(payload.keys())
|
||||||
|
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||||
|
if required.issubset(keys):
|
||||||
|
return kind
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StagingValidator:
|
||||||
|
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||||
|
|
||||||
|
timestamp_atol: float = 0.0 # exact-match by default
|
||||||
|
dataset_camera_keys: tuple[str, ...] | None = None
|
||||||
|
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||||
|
validator additionally enforces that every view-dependent row's
|
||||||
|
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||||
|
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||||
|
|
||||||
|
def validate(
|
||||||
|
self,
|
||||||
|
records: Sequence[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
) -> ValidationReport:
|
||||||
|
report = ValidationReport()
|
||||||
|
for record in records:
|
||||||
|
self._validate_episode(record, staging_dir, report)
|
||||||
|
report.episodes_checked += 1
|
||||||
|
return report
|
||||||
|
|
||||||
|
def _validate_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
staging_dir: Path,
|
||||||
|
report: ValidationReport,
|
||||||
|
) -> None:
|
||||||
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
|
staged = staging.read_all()
|
||||||
|
all_rows: list[dict[str, Any]] = []
|
||||||
|
for module_name, rows in staged.items():
|
||||||
|
for row in rows:
|
||||||
|
row = {**row, "_module": module_name}
|
||||||
|
all_rows.append(row)
|
||||||
|
|
||||||
|
frame_ts = set(record.frame_timestamps)
|
||||||
|
|
||||||
|
events: list[dict[str, Any]] = []
|
||||||
|
persistent: list[dict[str, Any]] = []
|
||||||
|
for row in all_rows:
|
||||||
|
self._check_column_routing(row, report, record.episode_index)
|
||||||
|
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||||
|
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||||
|
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||||
|
try:
|
||||||
|
column = column_for_style(row.get("style"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if column == LANGUAGE_PERSISTENT:
|
||||||
|
persistent.append(row)
|
||||||
|
else:
|
||||||
|
events.append(row)
|
||||||
|
|
||||||
|
for row in events:
|
||||||
|
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||||
|
|
||||||
|
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||||
|
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||||
|
self._check_vqa_json(events, report, record.episode_index)
|
||||||
|
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||||
|
|
||||||
|
def _check_camera_field(
|
||||||
|
self,
|
||||||
|
row: dict[str, Any],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
dataset_camera_keys: Sequence[str] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||||
|
style = row.get("style")
|
||||||
|
camera = row.get("camera")
|
||||||
|
try:
|
||||||
|
validate_camera_field(style, camera)
|
||||||
|
except ValueError as exc:
|
||||||
|
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||||
|
return
|
||||||
|
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||||
|
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_vqa_uniqueness_per_frame_camera(
|
||||||
|
self,
|
||||||
|
events: Iterable[dict[str, Any]],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||||
|
counts: dict[tuple[float, str, str], int] = {}
|
||||||
|
for row in events:
|
||||||
|
if row.get("style") != "vqa":
|
||||||
|
continue
|
||||||
|
ts = row.get("timestamp")
|
||||||
|
camera = row.get("camera")
|
||||||
|
role = row.get("role")
|
||||||
|
if ts is None or camera is None or role is None:
|
||||||
|
continue # other validators flag these
|
||||||
|
key = (float(ts), str(camera), str(role))
|
||||||
|
counts[key] = counts.get(key, 0) + 1
|
||||||
|
for (ts, camera, role), n in counts.items():
|
||||||
|
if n > 1:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||||
|
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_column_routing(
|
||||||
|
self,
|
||||||
|
row: dict[str, Any],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
style = row.get("style")
|
||||||
|
module = row.get("_module")
|
||||||
|
try:
|
||||||
|
target_col = column_for_style(style)
|
||||||
|
except ValueError:
|
||||||
|
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||||
|
return
|
||||||
|
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||||
|
)
|
||||||
|
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_event_timestamp_alignment(
|
||||||
|
self,
|
||||||
|
row: dict[str, Any],
|
||||||
|
frame_ts: set[float],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
ts = row.get("timestamp")
|
||||||
|
if ts is None:
|
||||||
|
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||||
|
return
|
||||||
|
if self.timestamp_atol == 0.0:
|
||||||
|
if float(ts) not in frame_ts:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_speech_interjection_pairs(
|
||||||
|
self,
|
||||||
|
events: Iterable[dict[str, Any]],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
speech_ts: dict[float, int] = {}
|
||||||
|
interjection_ts: dict[float, int] = {}
|
||||||
|
for row in events:
|
||||||
|
ts = row.get("timestamp")
|
||||||
|
if ts is None:
|
||||||
|
continue
|
||||||
|
ts_f = float(ts)
|
||||||
|
if row.get("style") is None and row.get("role") == "assistant":
|
||||||
|
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||||
|
if row.get("style") == "interjection":
|
||||||
|
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||||
|
|
||||||
|
for ts in interjection_ts:
|
||||||
|
if ts not in speech_ts:
|
||||||
|
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||||
|
|
||||||
|
def _check_plan_memory_consistency(
|
||||||
|
self,
|
||||||
|
persistent: Sequence[dict[str, Any]],
|
||||||
|
events: Sequence[dict[str, Any]],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||||
|
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||||
|
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||||
|
interjection_ts = sorted(
|
||||||
|
{
|
||||||
|
float(r["timestamp"])
|
||||||
|
for r in events
|
||||||
|
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if persistent and not plan_ts:
|
||||||
|
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||||
|
# every interjection should have a same-timestamp plan refresh
|
||||||
|
for ts in interjection_ts:
|
||||||
|
if ts not in set(plan_ts):
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||||
|
)
|
||||||
|
# memory should be emitted at subtask boundaries (subset relation)
|
||||||
|
if memory_ts and subtask_ts:
|
||||||
|
mem_set = set(memory_ts)
|
||||||
|
sub_set = set(subtask_ts)
|
||||||
|
stray = sorted(mem_set - sub_set)
|
||||||
|
if stray:
|
||||||
|
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||||
|
|
||||||
|
def _check_vqa_json(
|
||||||
|
self,
|
||||||
|
events: Iterable[dict[str, Any]],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
for row in events:
|
||||||
|
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||||
|
continue
|
||||||
|
content = row.get("content")
|
||||||
|
if content is None:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
payload = json.loads(content)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
shape = classify_vqa_answer(payload)
|
||||||
|
if shape is None:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||||
|
)
|
||||||
@@ -0,0 +1,617 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Shared Qwen-VL client.
|
||||||
|
|
||||||
|
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||||
|
available (high throughput, JSON-guided decoding); transformers is the
|
||||||
|
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||||
|
into a real model.
|
||||||
|
|
||||||
|
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||||
|
|
||||||
|
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||||
|
- requests JSON output from the server,
|
||||||
|
- batches requests transparently,
|
||||||
|
- and reprompts once on a JSON parse failure with an inline correction
|
||||||
|
message before raising.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from .config import VlmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class VlmClient(Protocol):
|
||||||
|
"""Protocol every backend must implement."""
|
||||||
|
|
||||||
|
def generate_json(
|
||||||
|
self,
|
||||||
|
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||||
|
*,
|
||||||
|
max_new_tokens: int | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Generate one JSON-decoded response per messages list."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StubVlmClient:
|
||||||
|
"""Deterministic stub used in unit tests.
|
||||||
|
|
||||||
|
A test passes a callable that maps the *last user message text* (or, if
|
||||||
|
that is empty, the full message list) to a JSON-serializable response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||||
|
|
||||||
|
def generate_json(
|
||||||
|
self,
|
||||||
|
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||||
|
*,
|
||||||
|
max_new_tokens: int | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
return [self.responder(list(messages)) for messages in messages_batch]
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_to_json(text: str) -> Any:
|
||||||
|
text = text.strip()
|
||||||
|
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||||
|
while "<think>" in text and "</think>" in text:
|
||||||
|
start = text.find("<think>")
|
||||||
|
end = text.find("</think>", start) + len("</think>")
|
||||||
|
text = (text[:start] + text[end:]).strip()
|
||||||
|
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||||
|
if text.startswith("```"):
|
||||||
|
first = text.find("\n")
|
||||||
|
last = text.rfind("```")
|
||||||
|
if first != -1 and last != -1 and last > first:
|
||||||
|
text = text[first + 1 : last].strip()
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
# Fall back to extracting the first balanced {...} block.
|
||||||
|
obj_text = _extract_first_json_object(text)
|
||||||
|
if obj_text is None:
|
||||||
|
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||||
|
return json.loads(obj_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_first_json_object(text: str) -> str | None:
|
||||||
|
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||||
|
string literals. Returns ``None`` if no balanced block is found."""
|
||||||
|
start = text.find("{")
|
||||||
|
if start < 0:
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
ch = text[i]
|
||||||
|
if escape:
|
||||||
|
escape = False
|
||||||
|
continue
|
||||||
|
if ch == "\\":
|
||||||
|
escape = True
|
||||||
|
continue
|
||||||
|
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||||
|
# above already handled and reset it.
|
||||||
|
if ch == '"':
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if ch == "{":
|
||||||
|
depth += 1
|
||||||
|
elif ch == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return text[start : i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _GenericTextClient:
|
||||||
|
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||||
|
|
||||||
|
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||||
|
config: VlmConfig
|
||||||
|
|
||||||
|
def generate_json(
|
||||||
|
self,
|
||||||
|
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||||
|
*,
|
||||||
|
max_new_tokens: int | None = None,
|
||||||
|
temperature: float | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||||
|
temp = temperature if temperature is not None else self.config.temperature
|
||||||
|
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||||
|
out: list[Any] = []
|
||||||
|
for messages, text in zip(messages_batch, raw, strict=True):
|
||||||
|
try:
|
||||||
|
out.append(_strip_to_json(text))
|
||||||
|
continue
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
retry = list(messages) + [
|
||||||
|
{"role": "assistant", "content": text},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Your previous reply was not valid JSON. "
|
||||||
|
"Reply with strictly valid JSON, no prose, no fences."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||||
|
try:
|
||||||
|
out.append(_strip_to_json(retry_text))
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
# After retry: log preview and return None instead of crashing
|
||||||
|
# the whole pipeline. Modules treat None as "skip".
|
||||||
|
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||||
|
print(
|
||||||
|
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
out.append(None)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||||
|
"""Build the shared VLM client.
|
||||||
|
|
||||||
|
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||||
|
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||||
|
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||||
|
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||||
|
optionally auto-spawning the server via ``auto_serve`` /
|
||||||
|
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||||
|
backends were removed to keep the support surface to the HF Jobs path.
|
||||||
|
|
||||||
|
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||||
|
callable; it is rejected here to make accidental misuse obvious.
|
||||||
|
"""
|
||||||
|
if config.backend == "openai":
|
||||||
|
return _make_openai_client(config)
|
||||||
|
if config.backend == "stub":
|
||||||
|
raise ValueError(
|
||||||
|
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||||
|
)
|
||||||
|
if config.backend in {"vllm", "transformers"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||||
|
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||||
|
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||||
|
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||||
|
"""Backend that talks to any OpenAI-compatible server.
|
||||||
|
|
||||||
|
Compatible with ``vllm serve``, ``transformers serve``,
|
||||||
|
``ktransformers serve``, and hosted endpoints. By default the server
|
||||||
|
is expected to be already running. Set ``auto_serve=True`` to have
|
||||||
|
this client spawn one (default: ``transformers serve``), wait until
|
||||||
|
it's ready, and tear it down on process exit.
|
||||||
|
|
||||||
|
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||||
|
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||||
|
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||||
|
multi-frame ``video_url`` items where supported.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from openai import OpenAI # type: ignore[import-not-found]
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
api_base = config.api_base
|
||||||
|
api_key = config.api_key
|
||||||
|
auto_serve = config.auto_serve
|
||||||
|
api_bases: list[str] = [api_base]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||||
|
f"api_base={api_base} auto_serve={auto_serve}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
if auto_serve:
|
||||||
|
if config.parallel_servers > 1:
|
||||||
|
print(
|
||||||
|
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
api_bases = _spawn_parallel_inference_servers(config)
|
||||||
|
elif _server_is_up(api_base):
|
||||||
|
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||||
|
else:
|
||||||
|
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||||
|
api_base = _spawn_inference_server(config)
|
||||||
|
api_bases = [api_base]
|
||||||
|
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||||
|
|
||||||
|
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||||
|
# round-robin counter for parallel mode
|
||||||
|
rr_counter = {"i": 0}
|
||||||
|
|
||||||
|
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||||
|
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||||
|
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||||
|
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
rr_lock = threading.Lock()
|
||||||
|
|
||||||
|
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||||
|
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": config.model_id,
|
||||||
|
"messages": api_messages,
|
||||||
|
"max_tokens": max_tok,
|
||||||
|
"temperature": temp,
|
||||||
|
}
|
||||||
|
extra_body: dict[str, Any] = {}
|
||||||
|
if send_mm_kwargs and mm_kwargs:
|
||||||
|
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||||
|
if config.chat_template_kwargs:
|
||||||
|
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||||
|
if extra_body:
|
||||||
|
kwargs["extra_body"] = extra_body
|
||||||
|
with rr_lock:
|
||||||
|
chosen = clients[rr_counter["i"] % len(clients)]
|
||||||
|
rr_counter["i"] += 1
|
||||||
|
response = chosen.chat.completions.create(**kwargs)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|
||||||
|
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||||
|
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||||
|
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||||
|
# Parallel fan-out — vllm batches these on the server side.
|
||||||
|
max_workers = min(config.client_concurrency, len(batch))
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
|
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||||
|
return [f.result() for f in futures]
|
||||||
|
|
||||||
|
return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_serve_port(cmd: str, port: int) -> str:
|
||||||
|
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
|
||||||
|
if present, else append ``--port`` when the command omits it (leaving an
|
||||||
|
explicit ``--port`` untouched). Shared by the single- and parallel-server
|
||||||
|
paths so a serve_command never reaches the server with a literal
|
||||||
|
``{port}``."""
|
||||||
|
if "{port}" in cmd:
|
||||||
|
return cmd.replace("{port}", str(port))
|
||||||
|
if "--port" not in cmd:
|
||||||
|
return f"{cmd} --port {port}"
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||||
|
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||||
|
|
||||||
|
Each replica:
|
||||||
|
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||||
|
- listens on ``serve_port + i``
|
||||||
|
- is shut down via the same atexit hook as the single-server path
|
||||||
|
|
||||||
|
Returns the list of ``api_base`` URLs the client should round-robin
|
||||||
|
across.
|
||||||
|
"""
|
||||||
|
n = config.parallel_servers
|
||||||
|
api_bases: list[str] = []
|
||||||
|
procs: list[subprocess.Popen] = []
|
||||||
|
ready_events: list[threading.Event] = []
|
||||||
|
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||||
|
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||||
|
# "Starting vLLM API server" line and the route-listing line. The
|
||||||
|
# HTTP probe below is the ultimate fallback.
|
||||||
|
ready_markers = (
|
||||||
|
"Uvicorn running",
|
||||||
|
"Application startup complete",
|
||||||
|
"Starting vLLM API server",
|
||||||
|
"Available routes are",
|
||||||
|
)
|
||||||
|
# Single lock for all server-stream threads so multibyte chars from
|
||||||
|
# different servers don't interleave and tear UTF-8 sequences.
|
||||||
|
print_lock = threading.Lock()
|
||||||
|
|
||||||
|
base_cmd = config.serve_command or (
|
||||||
|
f"vllm serve {shlex.quote(config.model_id)} "
|
||||||
|
f"--tensor-parallel-size 1 "
|
||||||
|
f"--max-model-len {config.max_model_len or 32768} "
|
||||||
|
f"--uvicorn-log-level warning"
|
||||||
|
)
|
||||||
|
|
||||||
|
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||||
|
for i in range(n):
|
||||||
|
port = config.serve_port + i
|
||||||
|
gpu = i % num_gpus
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||||
|
cmd = _bind_serve_port(base_cmd, port)
|
||||||
|
api_base = f"http://localhost:{port}/v1"
|
||||||
|
api_bases.append(api_base)
|
||||||
|
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
shlex.split(cmd),
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
procs.append(proc)
|
||||||
|
ready = threading.Event()
|
||||||
|
ready_events.append(ready)
|
||||||
|
|
||||||
|
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||||
|
# Read whole lines and emit each line atomically under the
|
||||||
|
# shared print_lock so output from N servers stays readable.
|
||||||
|
assert p.stdout is not None
|
||||||
|
for line in iter(p.stdout.readline, ""):
|
||||||
|
with print_lock:
|
||||||
|
sys.stdout.write(f"[server-{idx}] {line}")
|
||||||
|
if not line.endswith(("\n", "\r")):
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
if any(m in line for m in ready_markers):
|
||||||
|
ev.set()
|
||||||
|
|
||||||
|
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||||
|
|
||||||
|
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||||
|
while not ev.is_set() and p.poll() is None:
|
||||||
|
if _server_is_up(base):
|
||||||
|
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||||
|
ev.set()
|
||||||
|
return
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||||
|
|
||||||
|
def _shutdown() -> None:
|
||||||
|
for i, p in enumerate(procs):
|
||||||
|
if p.poll() is None:
|
||||||
|
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||||
|
p.send_signal(signal.SIGINT)
|
||||||
|
for p in procs:
|
||||||
|
try:
|
||||||
|
p.wait(timeout=15)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
p.kill()
|
||||||
|
p.wait(timeout=5)
|
||||||
|
|
||||||
|
atexit.register(_shutdown)
|
||||||
|
|
||||||
|
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||||
|
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||||
|
for i, p in enumerate(procs):
|
||||||
|
if p.poll() is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||||
|
)
|
||||||
|
time.sleep(2)
|
||||||
|
if any(not ev.is_set() for ev in ready_events):
|
||||||
|
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||||
|
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||||
|
return api_bases
|
||||||
|
|
||||||
|
|
||||||
|
def _server_is_up(api_base: str) -> bool:
|
||||||
|
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||||
|
url = api_base.rstrip("/") + "/models"
|
||||||
|
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||||
|
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||||
|
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||||
|
# cannot reach this code path.
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||||
|
return resp.status == 200
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||||
|
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||||
|
accepts ``/v1/models``, and register a shutdown hook.
|
||||||
|
|
||||||
|
Streams the server's stdout/stderr to the parent terminal in
|
||||||
|
real-time on a background thread so users can see model-load
|
||||||
|
progress and errors as they happen.
|
||||||
|
|
||||||
|
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||||
|
"""
|
||||||
|
cmd = config.serve_command
|
||||||
|
if not cmd:
|
||||||
|
cmd = (
|
||||||
|
f"transformers serve {shlex.quote(config.model_id)} "
|
||||||
|
f"--port {config.serve_port} --continuous-batching"
|
||||||
|
)
|
||||||
|
# Bind the single server to ``serve_port`` (what ``api_base`` below
|
||||||
|
# targets): substitute a literal ``{port}`` placeholder, else append
|
||||||
|
# ``--port``. Without this a serve_command carrying ``{port}`` would
|
||||||
|
# reach the server unsubstituted and fail to parse.
|
||||||
|
cmd = _bind_serve_port(cmd, config.serve_port)
|
||||||
|
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||||
|
print(f"[server] launching: {cmd}", flush=True)
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
shlex.split(cmd),
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Watch the server output for the uvicorn readiness banner. This is
|
||||||
|
# more reliable than polling /v1/models because transformers serve
|
||||||
|
# rescans its cache on every model-list request, which can exceed
|
||||||
|
# the urllib timeout and trigger an infinite probe loop.
|
||||||
|
ready_event = threading.Event()
|
||||||
|
# See _spawn_parallel_inference_servers for why we accept these.
|
||||||
|
ready_markers = (
|
||||||
|
"Uvicorn running",
|
||||||
|
"Application startup complete",
|
||||||
|
"Starting vLLM API server",
|
||||||
|
"Available routes are",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _probe() -> None:
|
||||||
|
while not ready_event.is_set() and proc.poll() is None:
|
||||||
|
if _server_is_up(api_base):
|
||||||
|
print("[server] ready (http probe)", flush=True)
|
||||||
|
ready_event.set()
|
||||||
|
return
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
threading.Thread(target=_probe, daemon=True).start()
|
||||||
|
|
||||||
|
def _stream_output() -> None:
|
||||||
|
# Read raw chunks instead of iterating lines so tqdm progress
|
||||||
|
# bars (which overwrite using \r) flush in real time.
|
||||||
|
assert proc.stdout is not None
|
||||||
|
buf = ""
|
||||||
|
prefix_started = False
|
||||||
|
while True:
|
||||||
|
ch = proc.stdout.read(1)
|
||||||
|
if ch == "":
|
||||||
|
# process exited; flush any tail
|
||||||
|
if buf:
|
||||||
|
sys.stdout.write(buf)
|
||||||
|
sys.stdout.flush()
|
||||||
|
return
|
||||||
|
if not prefix_started:
|
||||||
|
sys.stdout.write("[server] ")
|
||||||
|
prefix_started = True
|
||||||
|
sys.stdout.write(ch)
|
||||||
|
sys.stdout.flush()
|
||||||
|
buf += ch
|
||||||
|
if ch in ("\n", "\r"):
|
||||||
|
if any(marker in buf for marker in ready_markers):
|
||||||
|
ready_event.set()
|
||||||
|
buf = ""
|
||||||
|
prefix_started = False
|
||||||
|
|
||||||
|
threading.Thread(target=_stream_output, daemon=True).start()
|
||||||
|
|
||||||
|
def _shutdown() -> None:
|
||||||
|
if proc.poll() is None:
|
||||||
|
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||||
|
proc.send_signal(signal.SIGINT)
|
||||||
|
try:
|
||||||
|
proc.wait(timeout=15)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
proc.wait(timeout=5)
|
||||||
|
|
||||||
|
atexit.register(_shutdown)
|
||||||
|
|
||||||
|
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
if proc.poll() is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||||
|
f"See [server] log lines above for the cause."
|
||||||
|
)
|
||||||
|
if ready_event.wait(timeout=2):
|
||||||
|
return api_base
|
||||||
|
proc.terminate()
|
||||||
|
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||||
|
|
||||||
|
|
||||||
|
def _to_openai_messages(
|
||||||
|
messages: Sequence[dict[str, Any]],
|
||||||
|
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||||
|
"""Convert internal messages to OpenAI chat format.
|
||||||
|
|
||||||
|
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||||
|
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||||
|
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||||
|
inside the content blocks (which transformers serve rejects).
|
||||||
|
|
||||||
|
File-URL video blocks are inlined as base64 data URLs.
|
||||||
|
"""
|
||||||
|
out_messages: list[dict[str, Any]] = []
|
||||||
|
mm_kwargs: dict[str, Any] = {}
|
||||||
|
for message in messages:
|
||||||
|
content = message.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
out_messages.append({"role": message["role"], "content": content})
|
||||||
|
continue
|
||||||
|
out_blocks: list[dict[str, Any]] = []
|
||||||
|
for block in content:
|
||||||
|
block_type = block.get("type") if isinstance(block, dict) else None
|
||||||
|
if block_type == "text":
|
||||||
|
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||||
|
elif block_type == "image":
|
||||||
|
out_blocks.append(
|
||||||
|
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||||
|
)
|
||||||
|
elif block_type == "video":
|
||||||
|
frames = block.get("video", [])
|
||||||
|
for img in frames:
|
||||||
|
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||||
|
elif block_type == "video_url":
|
||||||
|
video_url = dict(block["video_url"])
|
||||||
|
url = video_url.get("url", "")
|
||||||
|
if url.startswith("file://"):
|
||||||
|
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||||
|
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||||
|
fps = block.get("fps")
|
||||||
|
if fps is not None:
|
||||||
|
mm_kwargs["fps"] = fps
|
||||||
|
else:
|
||||||
|
out_blocks.append(block)
|
||||||
|
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||||
|
return out_messages, mm_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _file_to_data_url(path: str) -> str:
|
||||||
|
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||||
|
return f"data:video/mp4;base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_to_data_url(image: Any) -> str:
|
||||||
|
"""Encode a PIL.Image as a base64 data URL."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
image.save(buf, format="PNG")
|
||||||
|
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
return f"data:image/png;base64,{b64}"
|
||||||
@@ -0,0 +1,341 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Final parquet rewrite.
|
||||||
|
|
||||||
|
For every episode the writer:
|
||||||
|
|
||||||
|
1. reads the staged module outputs,
|
||||||
|
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||||
|
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||||
|
3. sorts each slice deterministically,
|
||||||
|
4. broadcasts the persistent slice across every frame in the episode,
|
||||||
|
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||||
|
exactly equals that frame's timestamp,
|
||||||
|
6. drops the legacy ``subtask_index`` column,
|
||||||
|
7. writes the parquet shard back in place.
|
||||||
|
|
||||||
|
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||||
|
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||||
|
struct for every speech atom. The tool *schema* (the description
|
||||||
|
of the ``say`` function and its parameters) is a fixed code constant —
|
||||||
|
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||||
|
it directly rather than reading a redundant per-row column.
|
||||||
|
|
||||||
|
Invariants enforced here (and re-checked by the validator):
|
||||||
|
|
||||||
|
- per-episode persistent slice is byte-identical across every frame;
|
||||||
|
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||||
|
(timestamps come straight from the source parquet — never recomputed);
|
||||||
|
- every row passes ``column_for_style(style)``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
||||||
|
from lerobot.datasets.language import (
|
||||||
|
EVENT_ONLY_STYLES,
|
||||||
|
LANGUAGE_EVENTS,
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
PERSISTENT_STYLES,
|
||||||
|
column_for_style,
|
||||||
|
validate_camera_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
from .staging import EpisodeStaging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Tool schema constants live in lerobot.datasets.language — single
|
||||||
|
# source of truth. Re-exported here so existing imports
|
||||||
|
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||||
|
# keep working.
|
||||||
|
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||||
|
|
||||||
|
|
||||||
|
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||||
|
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||||
|
# events are bucketed per-frame, but within a frame we still want determinism
|
||||||
|
return (
|
||||||
|
row.get("style") or "",
|
||||||
|
row.get("role") or "",
|
||||||
|
row.get("camera") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||||
|
"""Coerce a staged row into the language-column struct shape.
|
||||||
|
|
||||||
|
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||||
|
writer infers the parquet struct schema from insertion order, so
|
||||||
|
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||||
|
"""
|
||||||
|
camera = row.get("camera")
|
||||||
|
validate_camera_field(style, camera)
|
||||||
|
out: dict[str, Any] = {
|
||||||
|
"role": str(row["role"]),
|
||||||
|
"content": None if row.get("content") is None else str(row["content"]),
|
||||||
|
"style": style,
|
||||||
|
}
|
||||||
|
if with_timestamp:
|
||||||
|
out["timestamp"] = float(row["timestamp"])
|
||||||
|
out["camera"] = None if camera is None else str(camera)
|
||||||
|
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Coerce a staged row into the persistent column's struct shape."""
|
||||||
|
style = row.get("style")
|
||||||
|
if style not in PERSISTENT_STYLES:
|
||||||
|
raise ValueError(
|
||||||
|
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||||
|
"row would be misrouted under column_for_style()"
|
||||||
|
)
|
||||||
|
if "timestamp" not in row:
|
||||||
|
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||||
|
if "role" not in row:
|
||||||
|
# Friendly error from the writer instead of a raw KeyError below;
|
||||||
|
# the validator doesn't check ``role`` yet.
|
||||||
|
raise ValueError(f"persistent row missing role: {row!r}")
|
||||||
|
return _normalize_row(row, style, with_timestamp=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||||
|
style = row.get("style")
|
||||||
|
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||||
|
raise ValueError(
|
||||||
|
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||||
|
)
|
||||||
|
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||||
|
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||||
|
if "role" not in row:
|
||||||
|
raise ValueError(f"event row missing role: {row!r}")
|
||||||
|
return _normalize_row(row, style, with_timestamp=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||||
|
return list(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||||
|
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||||
|
has_content = row.get("content") is not None
|
||||||
|
has_tools = row.get("tool_calls") is not None
|
||||||
|
if not (has_content or has_tools):
|
||||||
|
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||||
|
if row.get("style") is None and not has_tools:
|
||||||
|
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||||
|
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||||
|
if row.get("style") is not None:
|
||||||
|
return # not a speech atom
|
||||||
|
if row.get("role") != "assistant":
|
||||||
|
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||||
|
if row.get("content") is not None:
|
||||||
|
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||||
|
tool_calls = row.get("tool_calls")
|
||||||
|
if not tool_calls or not isinstance(tool_calls, list):
|
||||||
|
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||||
|
first = tool_calls[0]
|
||||||
|
if not isinstance(first, dict):
|
||||||
|
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||||
|
if first.get("type") != "function":
|
||||||
|
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||||
|
fn = first.get("function") or {}
|
||||||
|
if fn.get("name") != "say":
|
||||||
|
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||||
|
args = fn.get("arguments") or {}
|
||||||
|
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||||
|
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LanguageColumnsWriter:
|
||||||
|
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||||
|
|
||||||
|
drop_existing_subtask_index: bool = True
|
||||||
|
|
||||||
|
def write_all(
|
||||||
|
self,
|
||||||
|
records: Sequence[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
root: Path,
|
||||||
|
) -> list[Path]:
|
||||||
|
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||||
|
for record in records:
|
||||||
|
episodes_by_path[record.data_path].append(record)
|
||||||
|
|
||||||
|
written: list[Path] = []
|
||||||
|
for path, eps in episodes_by_path.items():
|
||||||
|
self._rewrite_one(path, eps, staging_dir, root)
|
||||||
|
written.append(path)
|
||||||
|
return written
|
||||||
|
|
||||||
|
def _rewrite_one(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
episodes: Sequence[EpisodeRecord],
|
||||||
|
staging_dir: Path,
|
||||||
|
root: Path,
|
||||||
|
) -> None:
|
||||||
|
table = pq.read_table(path)
|
||||||
|
n_rows = table.num_rows
|
||||||
|
|
||||||
|
# Ensure we cover every episode in the file. Episodes that don't have
|
||||||
|
# staging artifacts are passed through with empty annotation lists —
|
||||||
|
# this keeps the writer idempotent and safe for partial reruns.
|
||||||
|
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||||
|
for record in episodes:
|
||||||
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
|
staged_per_ep[record.episode_index] = staging.read_all()
|
||||||
|
|
||||||
|
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||||
|
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||||
|
|
||||||
|
for ep_index, ep_staged in staged_per_ep.items():
|
||||||
|
persistent_rows: list[dict[str, Any]] = []
|
||||||
|
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||||
|
for _module_name, rows in ep_staged.items():
|
||||||
|
for row in rows:
|
||||||
|
style = row.get("style")
|
||||||
|
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||||
|
persistent_rows.append(row)
|
||||||
|
else:
|
||||||
|
event_rows.append(row)
|
||||||
|
|
||||||
|
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||||
|
normalized_persistent = []
|
||||||
|
for r in persistent_rows:
|
||||||
|
_validate_atom_invariants(r)
|
||||||
|
_validate_speech_atom(r)
|
||||||
|
normalized_persistent.append(_normalize_persistent_row(r))
|
||||||
|
persistent_by_ep[ep_index] = normalized_persistent
|
||||||
|
|
||||||
|
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||||
|
for r in event_rows:
|
||||||
|
_validate_atom_invariants(r)
|
||||||
|
_validate_speech_atom(r)
|
||||||
|
ts = float(r["timestamp"])
|
||||||
|
buckets[ts].append(_normalize_event_row(r))
|
||||||
|
for ts in list(buckets.keys()):
|
||||||
|
buckets[ts].sort(key=_row_event_sort_key)
|
||||||
|
events_by_ep_ts[ep_index] = buckets
|
||||||
|
|
||||||
|
episode_col = (
|
||||||
|
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||||
|
)
|
||||||
|
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||||
|
if episode_col is None or ts_col is None:
|
||||||
|
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||||
|
|
||||||
|
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||||
|
per_row_events: list[list[dict[str, Any]]] = []
|
||||||
|
for i in range(n_rows):
|
||||||
|
ep = episode_col[i]
|
||||||
|
ts = float(ts_col[i])
|
||||||
|
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||||
|
buckets = events_by_ep_ts.get(ep, {})
|
||||||
|
per_row_events.append(buckets.get(ts, []))
|
||||||
|
|
||||||
|
new_table = self._materialize_table(
|
||||||
|
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||||
|
)
|
||||||
|
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
||||||
|
# them into one). Write to a sibling tmp path and atomically rename so a
|
||||||
|
# crash mid-write can't leave a half-written shard.
|
||||||
|
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
write_table_one_row_group_per_episode(new_table, tmp_path)
|
||||||
|
tmp_path.replace(path)
|
||||||
|
|
||||||
|
def _materialize_table(
|
||||||
|
self,
|
||||||
|
table: pa.Table,
|
||||||
|
persistent: list[list[dict[str, Any]]],
|
||||||
|
events: list[list[dict[str, Any]]],
|
||||||
|
*,
|
||||||
|
drop_old: bool,
|
||||||
|
) -> pa.Table:
|
||||||
|
cols = []
|
||||||
|
names = []
|
||||||
|
for name in table.column_names:
|
||||||
|
if drop_old and name == "subtask_index":
|
||||||
|
continue
|
||||||
|
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||||
|
continue # we'll re-add canonical versions
|
||||||
|
# Strip any legacy ``tools`` column previously emitted by older
|
||||||
|
# writers — the schema no longer uses it (constant lives in
|
||||||
|
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||||
|
if name == "tools":
|
||||||
|
continue
|
||||||
|
cols.append(table.column(name))
|
||||||
|
names.append(name)
|
||||||
|
|
||||||
|
# We let pyarrow infer struct/list schema rather than passing the
|
||||||
|
# canonical type from `lerobot.datasets.language` directly: that type
|
||||||
|
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||||
|
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||||
|
# current pyarrow versions. The inferred schema round-trips through
|
||||||
|
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||||
|
# exercises the same flow.
|
||||||
|
persistent_arr = pa.array(persistent)
|
||||||
|
events_arr = pa.array(events)
|
||||||
|
|
||||||
|
cols.extend([persistent_arr, events_arr])
|
||||||
|
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||||
|
|
||||||
|
return pa.Table.from_arrays(cols, names=names)
|
||||||
|
|
||||||
|
|
||||||
|
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||||
|
"""Build a canonical speech tool-call atom for the events column."""
|
||||||
|
return {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"style": None,
|
||||||
|
"timestamp": float(timestamp),
|
||||||
|
"camera": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "say",
|
||||||
|
"arguments": {"text": text},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -442,11 +442,12 @@ class OpenCVCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
if self.stop_event is None:
|
stop_event = self.stop_event
|
||||||
|
if stop_event is None:
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not self.stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
raw_frame = self._read_from_hardware()
|
raw_frame = self._read_from_hardware()
|
||||||
processed_frame = self._postprocess_image(raw_frame)
|
processed_frame = self._postprocess_image(raw_frame)
|
||||||
|
|||||||
@@ -471,11 +471,12 @@ class RealSenseCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
if self.stop_event is None:
|
stop_event = self.stop_event
|
||||||
|
if stop_event is None:
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not self.stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
color_frame_raw = frame.get_color_frame()
|
color_frame_raw = frame.get_color_frame()
|
||||||
|
|||||||
@@ -246,11 +246,12 @@ class ZMQCamera(Camera):
|
|||||||
"""
|
"""
|
||||||
Internal loop run by the background thread for asynchronous reading.
|
Internal loop run by the background thread for asynchronous reading.
|
||||||
"""
|
"""
|
||||||
if self.stop_event is None:
|
stop_event = self.stop_event
|
||||||
|
if stop_event is None:
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not self.stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
capture_time = time.perf_counter()
|
capture_time = time.perf_counter()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Teleoperator smooth handover helpers
|
||||||
|
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||||
|
# being a root module.
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_supports_feedback(teleop) -> bool:
|
||||||
|
"""Return True when the teleop can receive position feedback (is actuated).
|
||||||
|
|
||||||
|
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||||
|
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||||
|
|
||||||
|
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
bool(teleop.feedback_features)
|
||||||
|
and hasattr(teleop, "disable_torque")
|
||||||
|
and hasattr(teleop, "enable_torque")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||||
|
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||||
|
|
||||||
|
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||||
|
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||||
|
|
||||||
|
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||||
|
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||||
|
the robot action key space directly.
|
||||||
|
|
||||||
|
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||||
|
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||||
|
"""
|
||||||
|
teleop.enable_torque()
|
||||||
|
current = teleop.get_action()
|
||||||
|
steps = max(int(duration_s * fps), 1)
|
||||||
|
|
||||||
|
for step in range(steps + 1):
|
||||||
|
t = step / steps
|
||||||
|
interp = {
|
||||||
|
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||||
|
}
|
||||||
|
teleop.send_feedback(interp)
|
||||||
|
time.sleep(1 / fps)
|
||||||
|
|
||||||
|
|
||||||
|
def follower_smooth_move_to(
|
||||||
|
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||||
|
) -> None:
|
||||||
|
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||||
|
|
||||||
|
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||||
|
the follower, the follower is brought to the teleop's current pose so the
|
||||||
|
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||||
|
|
||||||
|
Both ``current`` and ``target`` must be in the robot action key space
|
||||||
|
(i.e. the output of ``robot_action_processor``).
|
||||||
|
"""
|
||||||
|
steps = max(int(duration_s * fps), 1)
|
||||||
|
|
||||||
|
for step in range(steps + 1):
|
||||||
|
t = step / steps
|
||||||
|
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||||
|
robot.send_action(interp)
|
||||||
|
time.sleep(1 / fps)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.optim import (
|
from lerobot.optim import (
|
||||||
load_optimizer_state,
|
load_optimizer_state,
|
||||||
|
load_optimizer_state_dict,
|
||||||
load_scheduler_state,
|
load_scheduler_state,
|
||||||
save_optimizer_state,
|
save_optimizer_state,
|
||||||
save_scheduler_state,
|
save_scheduler_state,
|
||||||
@@ -49,8 +50,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
|||||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||||
|
|
||||||
|
|
||||||
def save_training_step(step: int, save_dir: Path) -> None:
|
def save_training_step(
|
||||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
|
||||||
|
) -> None:
|
||||||
|
state: dict = {"step": step}
|
||||||
|
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
|
||||||
|
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
|
||||||
|
# produced `step`, since both scale how many sampler positions a step consumes (see
|
||||||
|
# compute_sampler_state).
|
||||||
|
if num_processes is not None:
|
||||||
|
state["num_processes"] = num_processes
|
||||||
|
if batch_size is not None:
|
||||||
|
state["batch_size"] = batch_size
|
||||||
|
write_json(state, save_dir / TRAINING_STEP)
|
||||||
|
|
||||||
|
|
||||||
def load_training_step(save_dir: Path) -> int:
|
def load_training_step(save_dir: Path) -> int:
|
||||||
@@ -58,6 +70,16 @@ def load_training_step(save_dir: Path) -> int:
|
|||||||
return training_step["step"]
|
return training_step["step"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||||
|
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||||
|
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
|
||||||
|
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
|
||||||
|
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
|
||||||
|
|
||||||
|
|
||||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||||
if last_checkpoint_dir.is_symlink():
|
if last_checkpoint_dir.is_symlink():
|
||||||
@@ -75,6 +97,10 @@ def save_checkpoint(
|
|||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
preprocessor: PolicyProcessorPipeline | None = None,
|
preprocessor: PolicyProcessorPipeline | None = None,
|
||||||
postprocessor: PolicyProcessorPipeline | None = None,
|
postprocessor: PolicyProcessorPipeline | None = None,
|
||||||
|
num_processes: int | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
|
model_state_dict: dict | None = None,
|
||||||
|
optim_state_dict: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function creates the following directory structure:
|
"""This function creates the following directory structure:
|
||||||
|
|
||||||
@@ -100,9 +126,22 @@ def save_checkpoint(
|
|||||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||||
|
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||||
|
resume. Defaults to None (not recorded).
|
||||||
|
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||||
|
resume. Defaults to None (not recorded).
|
||||||
|
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
||||||
|
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
||||||
|
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
||||||
|
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
||||||
|
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
||||||
|
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
|
||||||
|
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
|
||||||
|
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
||||||
cfg.save_pretrained(pretrained_dir)
|
cfg.save_pretrained(pretrained_dir)
|
||||||
if cfg.peft is not None:
|
if cfg.peft is not None:
|
||||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||||
@@ -112,7 +151,15 @@ def save_checkpoint(
|
|||||||
preprocessor.save_pretrained(pretrained_dir)
|
preprocessor.save_pretrained(pretrained_dir)
|
||||||
if postprocessor is not None:
|
if postprocessor is not None:
|
||||||
postprocessor.save_pretrained(pretrained_dir)
|
postprocessor.save_pretrained(pretrained_dir)
|
||||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
save_training_state(
|
||||||
|
checkpoint_dir,
|
||||||
|
step,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
num_processes=num_processes,
|
||||||
|
batch_size=batch_size,
|
||||||
|
optim_state_dict=optim_state_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_training_state(
|
def save_training_state(
|
||||||
@@ -120,6 +167,9 @@ def save_training_state(
|
|||||||
train_step: int,
|
train_step: int,
|
||||||
optimizer: Optimizer | None = None,
|
optimizer: Optimizer | None = None,
|
||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
|
num_processes: int | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
|
optim_state_dict: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||||
@@ -131,19 +181,23 @@ def save_training_state(
|
|||||||
Defaults to None.
|
Defaults to None.
|
||||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||||
|
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||||
|
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
|
||||||
|
`optimizer.state_dict()` when provided. Defaults to None.
|
||||||
"""
|
"""
|
||||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
save_training_step(train_step, save_dir)
|
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||||
save_rng_state(save_dir)
|
save_rng_state(save_dir)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
save_optimizer_state(optimizer, save_dir)
|
save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
save_scheduler_state(scheduler, save_dir)
|
save_scheduler_state(scheduler, save_dir)
|
||||||
|
|
||||||
|
|
||||||
def load_training_state(
|
def load_training_state(
|
||||||
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
|
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True
|
||||||
) -> tuple[int, Optimizer, LRScheduler | None]:
|
) -> tuple[int, Optimizer, LRScheduler | None]:
|
||||||
"""
|
"""
|
||||||
Loads the training step, optimizer state, scheduler state, and rng state.
|
Loads the training step, optimizer state, scheduler state, and rng state.
|
||||||
@@ -153,6 +207,10 @@ def load_training_state(
|
|||||||
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
||||||
optimizer (Optimizer): The optimizer to load the state_dict to.
|
optimizer (Optimizer): The optimizer to load the state_dict to.
|
||||||
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
||||||
|
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
|
||||||
|
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
|
||||||
|
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
|
||||||
|
untouched here).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
||||||
@@ -167,8 +225,61 @@ def load_training_state(
|
|||||||
|
|
||||||
load_rng_state(training_state_dir)
|
load_rng_state(training_state_dir)
|
||||||
step = load_training_step(training_state_dir)
|
step = load_training_step(training_state_dir)
|
||||||
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
if load_optimizer:
|
||||||
|
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
||||||
|
|
||||||
return step, optimizer, scheduler
|
return step, optimizer, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
|
||||||
|
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
|
||||||
|
|
||||||
|
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
|
||||||
|
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
|
||||||
|
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
|
||||||
|
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
|
||||||
|
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
|
||||||
|
independent; `load_fsdp_optimizer_state` reshards it on resume.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
|
||||||
|
"""
|
||||||
|
from torch.distributed.fsdp import (
|
||||||
|
FullOptimStateDictConfig,
|
||||||
|
FullStateDictConfig,
|
||||||
|
FullyShardedDataParallel as FSDP, # noqa F401
|
||||||
|
StateDictType,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||||
|
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||||
|
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
|
||||||
|
return model_state_dict, optim_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
|
||||||
|
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
|
||||||
|
|
||||||
|
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
|
||||||
|
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
|
||||||
|
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
|
||||||
|
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
|
||||||
|
"""
|
||||||
|
from torch.distributed.fsdp import (
|
||||||
|
FullOptimStateDictConfig,
|
||||||
|
FullStateDictConfig,
|
||||||
|
FullyShardedDataParallel as FSDP, # noqa F401
|
||||||
|
StateDictType,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
|
||||||
|
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
|
||||||
|
state_cfg = FullStateDictConfig(rank0_only=False)
|
||||||
|
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
|
||||||
|
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||||
|
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
|
||||||
|
optimizer.load_state_dict(sharded_osd)
|
||||||
|
|||||||
@@ -180,24 +180,26 @@ class WandBLogger:
|
|||||||
self._wandb_custom_step_key.add(new_custom_key)
|
self._wandb_custom_step_key.add(new_custom_key)
|
||||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||||
|
|
||||||
|
batch_data = {}
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
|
# Skip the custom step key here, it's added to the batch below.
|
||||||
|
if custom_step_key is not None and k == custom_step_key:
|
||||||
|
continue
|
||||||
|
|
||||||
if not isinstance(v, (int | float | str)):
|
if not isinstance(v, (int | float | str)):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Do not log the custom step key itself.
|
batch_data[f"{mode}/{k}"] = v
|
||||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
if batch_data:
|
||||||
if custom_step_key is not None:
|
if custom_step_key is not None:
|
||||||
value_custom_step = d[custom_step_key]
|
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
self._wandb.log(batch_data)
|
||||||
self._wandb.log(data)
|
else:
|
||||||
continue
|
self._wandb.log(data=batch_data, step=step)
|
||||||
|
|
||||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
|
|||||||
video: bool = True
|
video: bool = True
|
||||||
# Upload dataset to Hugging Face hub.
|
# Upload dataset to Hugging Face hub.
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
# Upload on private repository on the Hugging Face hub.
|
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||||
private: bool = False
|
private: bool | None = None
|
||||||
# Add tags to your dataset on the hub.
|
# Add tags to your dataset on the hub.
|
||||||
tags: list[str] | None = None
|
tags: list[str] | None = None
|
||||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||||
|
|||||||
@@ -39,8 +39,12 @@ class DatasetConfig:
|
|||||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||||
return_uint8: bool = False
|
return_uint8: bool = False
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
|
||||||
|
eval_split: float = 0.0
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
if not (0.0 <= self.eval_split < 1.0):
|
||||||
|
raise ValueError(f"eval_split must be in [0.0, 1.0), got {self.eval_split}")
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
if any(ep < 0 for ep in self.episodes):
|
if any(ep < 0 for ep in self.episodes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
|||||||
remaining = config_data[field]
|
remaining = config_data[field]
|
||||||
if remaining:
|
if remaining:
|
||||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||||
else:
|
del config_data[field]
|
||||||
del config_data[field]
|
|
||||||
modified = True
|
modified = True
|
||||||
|
|
||||||
if not modified:
|
if not modified:
|
||||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
|||||||
cli_args = filter_arg("config_path", cli_args)
|
cli_args = filter_arg("config_path", cli_args)
|
||||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
if config_path_cli:
|
||||||
|
cli_args = filter_arg("config_path", cli_args)
|
||||||
|
cfg = draccus.parse(
|
||||||
|
config_class=argtype,
|
||||||
|
config_path=config_path_cli or config_path,
|
||||||
|
args=cli_args,
|
||||||
|
)
|
||||||
response = fn(cfg, *args, **kwargs)
|
response = fn(cfg, *args, **kwargs)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||||
pretrained_path: Path | None = None
|
pretrained_path: Path | None = None
|
||||||
|
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||||
|
pretrained_revision: str | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if not self.device or not is_torch_device_available(self.device):
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
|||||||
device: str | None = None
|
device: str | None = None
|
||||||
|
|
||||||
pretrained_path: str | None = None
|
pretrained_path: str | None = None
|
||||||
|
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||||
|
pretrained_revision: str | None = None
|
||||||
|
|
||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
repo_id: str | None = None
|
repo_id: str | None = None
|
||||||
|
|||||||
@@ -100,8 +100,13 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
prefetch_factor: int = 4
|
prefetch_factor: int = 4
|
||||||
persistent_workers: bool = True
|
persistent_workers: bool = True
|
||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
|
||||||
|
env_eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
|
||||||
|
eval_steps: int = 0
|
||||||
|
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
|
||||||
|
max_eval_samples: int = 0
|
||||||
tolerance_s: float = 1e-4
|
tolerance_s: float = 1e-4
|
||||||
save_checkpoint: bool = True
|
save_checkpoint: bool = True
|
||||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
@@ -177,6 +182,12 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
active_cfg = self.trainable_config
|
active_cfg = self.trainable_config
|
||||||
|
if self.rename_map and active_cfg.pretrained_path is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`rename_map` requires a pretrained policy checkpoint. "
|
||||||
|
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||||
|
)
|
||||||
|
|
||||||
if not self.job_name:
|
if not self.job_name:
|
||||||
if self.env is None:
|
if self.env is None:
|
||||||
self.job_name = f"{active_cfg.type}"
|
self.job_name = f"{active_cfg.type}"
|
||||||
@@ -202,6 +213,9 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
self.optimizer = active_cfg.get_optimizer_preset()
|
self.optimizer = active_cfg.get_optimizer_preset()
|
||||||
self.scheduler = active_cfg.get_scheduler_preset()
|
self.scheduler = active_cfg.get_scheduler_preset()
|
||||||
|
|
||||||
|
if self.eval_steps > 0 and self.dataset.eval_split == 0.0:
|
||||||
|
raise ValueError("eval_steps > 0 requires dataset.eval_split > 0.0 to hold out eval data.")
|
||||||
|
|
||||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from .dataset_tools import (
|
|||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
from .factory import make_dataset, resolve_delta_timestamps
|
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
|
||||||
from .image_writer import safe_stop_image_writer
|
from .image_writer import safe_stop_image_writer
|
||||||
from .io_utils import load_episodes, write_stats
|
from .io_utils import load_episodes, write_stats
|
||||||
from .language import (
|
from .language import (
|
||||||
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
|||||||
from .multi_dataset import MultiLeRobotDataset
|
from .multi_dataset import MultiLeRobotDataset
|
||||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||||
from .sampler import EpisodeAwareSampler
|
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||||
from .streaming_dataset import StreamingLeRobotDataset
|
from .streaming_dataset import StreamingLeRobotDataset
|
||||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||||
from .video_utils import VideoEncodingManager
|
from .video_utils import VideoEncodingManager
|
||||||
@@ -82,12 +82,14 @@ __all__ = [
|
|||||||
"aggregate_stats",
|
"aggregate_stats",
|
||||||
"convert_image_to_video_dataset",
|
"convert_image_to_video_dataset",
|
||||||
"create_initial_features",
|
"create_initial_features",
|
||||||
|
"compute_sampler_state",
|
||||||
"create_lerobot_dataset_card",
|
"create_lerobot_dataset_card",
|
||||||
"column_for_style",
|
"column_for_style",
|
||||||
"delete_episodes",
|
"delete_episodes",
|
||||||
"get_feature_stats",
|
"get_feature_stats",
|
||||||
"load_episodes",
|
"load_episodes",
|
||||||
"make_dataset",
|
"make_dataset",
|
||||||
|
"make_train_eval_datasets",
|
||||||
"merge_datasets",
|
"merge_datasets",
|
||||||
"modify_features",
|
"modify_features",
|
||||||
"modify_tasks",
|
"modify_tasks",
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
|||||||
from .io_utils import (
|
from .io_utils import (
|
||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
|
to_parquet_one_row_group_per_episode,
|
||||||
to_parquet_with_hf_images,
|
to_parquet_with_hf_images,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
@@ -286,6 +287,8 @@ def aggregate_datasets(
|
|||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
|
concatenate_videos: bool = True,
|
||||||
|
concatenate_data: bool = True,
|
||||||
):
|
):
|
||||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||||
|
|
||||||
@@ -303,6 +306,8 @@ def aggregate_datasets(
|
|||||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||||
|
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||||
|
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||||
"""
|
"""
|
||||||
logging.info("Start aggregate_datasets")
|
logging.info("Start aggregate_datasets")
|
||||||
|
|
||||||
@@ -351,8 +356,12 @@ def aggregate_datasets(
|
|||||||
dst_meta.episodes = {}
|
dst_meta.episodes = {}
|
||||||
|
|
||||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
videos_idx = aggregate_videos(
|
||||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
|
||||||
|
)
|
||||||
|
data_idx = aggregate_data(
|
||||||
|
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
|
||||||
|
)
|
||||||
|
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||||
|
|
||||||
@@ -367,7 +376,9 @@ def aggregate_datasets(
|
|||||||
logging.info("Aggregation complete.")
|
logging.info("Aggregation complete.")
|
||||||
|
|
||||||
|
|
||||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
def aggregate_videos(
|
||||||
|
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
|
||||||
|
):
|
||||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||||
|
|
||||||
Handles video file concatenation and rotation based on file size limits.
|
Handles video file concatenation and rotation based on file size limits.
|
||||||
@@ -379,6 +390,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
videos_idx: Dictionary tracking video chunk and file indices.
|
videos_idx: Dictionary tracking video chunk and file indices.
|
||||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||||
|
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated videos_idx with current chunk and file indices.
|
dict: Updated videos_idx with current chunk and file indices.
|
||||||
"""
|
"""
|
||||||
@@ -439,7 +451,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
src_size = get_file_size_in_mb(src_path)
|
src_size = get_file_size_in_mb(src_path)
|
||||||
dst_size = get_file_size_in_mb(dst_path)
|
dst_size = get_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
if dst_size + src_size >= video_files_size_in_mb:
|
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
|
||||||
# Rotate to a new file - offset is 0
|
# Rotate to a new file - offset is 0
|
||||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||||
dst_key = (chunk_idx, file_idx)
|
dst_key = (chunk_idx, file_idx)
|
||||||
@@ -477,7 +489,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
|||||||
return videos_idx
|
return videos_idx
|
||||||
|
|
||||||
|
|
||||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
|
||||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||||
|
|
||||||
Reads source data files, updates indices to match the aggregated dataset,
|
Reads source data files, updates indices to match the aggregated dataset,
|
||||||
@@ -493,6 +505,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
data_idx: Dictionary tracking data chunk and file indices.
|
data_idx: Dictionary tracking data chunk and file indices.
|
||||||
data_files_size_in_mb: Maximum size for data files in MB.
|
data_files_size_in_mb: Maximum size for data files in MB.
|
||||||
chunk_size: Maximum number of files per chunk.
|
chunk_size: Maximum number of files per chunk.
|
||||||
|
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Updated data_idx with current chunk and file indices.
|
dict: Updated data_idx with current chunk and file indices.
|
||||||
@@ -538,6 +551,8 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
contains_images=contains_images,
|
contains_images=contains_images,
|
||||||
aggr_root=dst_meta.root,
|
aggr_root=dst_meta.root,
|
||||||
hf_features=hf_features,
|
hf_features=hf_features,
|
||||||
|
concatenate=concatenate_data,
|
||||||
|
one_row_group_per_episode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record the mapping from source to actual destination
|
# Record the mapping from source to actual destination
|
||||||
@@ -614,6 +629,8 @@ def append_or_create_parquet_file(
|
|||||||
contains_images: bool = False,
|
contains_images: bool = False,
|
||||||
aggr_root: Path = None,
|
aggr_root: Path = None,
|
||||||
hf_features: datasets.Features | None = None,
|
hf_features: datasets.Features | None = None,
|
||||||
|
concatenate: bool = True,
|
||||||
|
one_row_group_per_episode: bool = False,
|
||||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||||
|
|
||||||
@@ -630,6 +647,9 @@ def append_or_create_parquet_file(
|
|||||||
contains_images: Whether the data contains images requiring special handling.
|
contains_images: Whether the data contains images requiring special handling.
|
||||||
aggr_root: Root path for the aggregated dataset.
|
aggr_root: Root path for the aggregated dataset.
|
||||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||||
|
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||||
|
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
||||||
|
the episodes-metadata parquet (already one row per episode).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||||
@@ -642,6 +662,8 @@ def append_or_create_parquet_file(
|
|||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||||
|
elif one_row_group_per_episode:
|
||||||
|
to_parquet_one_row_group_per_episode(df, dst_path)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
df.to_parquet(dst_path)
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
@@ -649,7 +671,7 @@ def append_or_create_parquet_file(
|
|||||||
src_size = get_parquet_file_size_in_mb(src_path)
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
if dst_size + src_size >= max_mb:
|
if not concatenate or dst_size + src_size >= max_mb:
|
||||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
@@ -668,6 +690,8 @@ def append_or_create_parquet_file(
|
|||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||||
|
elif one_row_group_per_episode:
|
||||||
|
to_parquet_one_row_group_per_episode(final_df, target_path)
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ class RunningQuantileStats:
|
|||||||
batch: An array where all dimensions except the last are batch dimensions.
|
batch: An array where all dimensions except the last are batch dimensions.
|
||||||
"""
|
"""
|
||||||
batch = batch.reshape(-1, batch.shape[-1])
|
batch = batch.reshape(-1, batch.shape[-1])
|
||||||
|
# Promote integer and low-precision inputs before computing squared statistics.
|
||||||
|
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
|
||||||
num_elements, vector_length = batch.shape
|
num_elements, vector_length = batch.shape
|
||||||
|
|
||||||
if self._count == 0:
|
if self._count == 0:
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
features = {**features, **DEFAULT_FEATURES}
|
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
|
|
||||||
obj.tasks = None
|
obj.tasks = None
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ class DatasetReader:
|
|||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self._tolerance_s = tolerance_s
|
self._tolerance_s = tolerance_s
|
||||||
self._video_backend = video_backend
|
self._video_backend = video_backend
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
self._image_transforms = image_transforms
|
self._image_transforms = image_transforms
|
||||||
self._return_uint8 = return_uint8
|
self._return_uint8 = return_uint8
|
||||||
|
|
||||||
@@ -86,6 +88,16 @@ class DatasetReader:
|
|||||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||||
|
|
||||||
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
|
"""Replace the transform applied to visual observations."""
|
||||||
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
|
self._image_transforms = image_transforms
|
||||||
|
|
||||||
|
def clear_image_transforms(self) -> None:
|
||||||
|
"""Remove the transform applied to visual observations."""
|
||||||
|
self._image_transforms = None
|
||||||
|
|
||||||
def try_load(self) -> bool:
|
def try_load(self) -> bool:
|
||||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import logging
|
|||||||
import shutil
|
import shutil
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -261,6 +262,8 @@ def merge_datasets(
|
|||||||
datasets: list[LeRobotDataset],
|
datasets: list[LeRobotDataset],
|
||||||
output_repo_id: str,
|
output_repo_id: str,
|
||||||
output_dir: str | Path | None = None,
|
output_dir: str | Path | None = None,
|
||||||
|
concatenate_videos: bool = True,
|
||||||
|
concatenate_data: bool = True,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||||
|
|
||||||
@@ -270,6 +273,8 @@ def merge_datasets(
|
|||||||
datasets: List of LeRobotDatasets to merge.
|
datasets: List of LeRobotDatasets to merge.
|
||||||
output_repo_id: Merged dataset identifier.
|
output_repo_id: Merged dataset identifier.
|
||||||
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||||
|
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
|
||||||
|
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
|
||||||
"""
|
"""
|
||||||
if not datasets:
|
if not datasets:
|
||||||
raise ValueError("No datasets to merge")
|
raise ValueError("No datasets to merge")
|
||||||
@@ -284,6 +289,8 @@ def merge_datasets(
|
|||||||
aggr_repo_id=output_repo_id,
|
aggr_repo_id=output_repo_id,
|
||||||
roots=roots,
|
roots=roots,
|
||||||
aggr_root=output_dir,
|
aggr_root=output_dir,
|
||||||
|
concatenate_videos=concatenate_videos,
|
||||||
|
concatenate_data=concatenate_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
merged_dataset = LeRobotDataset(
|
merged_dataset = LeRobotDataset(
|
||||||
@@ -1095,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
|||||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||||
for key in dst_meta.video_keys:
|
for key in dst_meta.video_keys:
|
||||||
if key in src_dataset.meta.features:
|
if key in src_dataset.meta.features:
|
||||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
dst_meta.info.features[key]["info"] = deepcopy(
|
||||||
|
src_dataset.meta.info.features[key].get("info", {})
|
||||||
|
)
|
||||||
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_eval_datasets(
|
||||||
|
cfg: TrainPipelineConfig,
|
||||||
|
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
|
||||||
|
"""Create train and optional eval datasets by splitting episodes based on eval_split.
|
||||||
|
|
||||||
|
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
|
||||||
|
If eval_split == 0.0, returns (full_dataset, None).
|
||||||
|
"""
|
||||||
|
full_dataset = make_dataset(cfg)
|
||||||
|
|
||||||
|
if cfg.dataset.eval_split == 0.0:
|
||||||
|
return full_dataset, None
|
||||||
|
|
||||||
|
base_episodes = (
|
||||||
|
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_tasks = full_dataset.meta.episodes["tasks"]
|
||||||
|
task_to_episodes: dict[str, list[int]] = {}
|
||||||
|
for ep_idx in base_episodes:
|
||||||
|
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
|
||||||
|
task_to_episodes.setdefault(task_key, []).append(ep_idx)
|
||||||
|
|
||||||
|
train_episodes, eval_episodes = [], []
|
||||||
|
for eps in task_to_episodes.values():
|
||||||
|
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
|
||||||
|
train_episodes.extend(eps[: len(eps) - n_eval])
|
||||||
|
eval_episodes.extend(eps[len(eps) - n_eval :])
|
||||||
|
|
||||||
|
if not train_episodes:
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
|
||||||
|
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
|
||||||
|
)
|
||||||
|
|
||||||
|
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
|
||||||
|
|
||||||
|
train_image_transforms = (
|
||||||
|
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=train_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=train_image_transforms,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=eval_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=None,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.dataset.use_imagenet_stats:
|
||||||
|
for ds in (train_dataset, eval_dataset):
|
||||||
|
for key in ds.meta.camera_keys:
|
||||||
|
for stats_type, stats in IMAGENET_STATS.items():
|
||||||
|
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
|
return train_dataset, eval_dataset
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas
|
import pandas
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
import pyarrow.dataset as pa_ds
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
@@ -153,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: The statistics dictionary with values cast to numpy arrays.
|
dict: The statistics dictionary with values cast to numpy arrays.
|
||||||
"""
|
"""
|
||||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||||
return unflatten_dict(stats)
|
return unflatten_dict(stats)
|
||||||
|
|
||||||
|
|
||||||
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
|||||||
return items_dict
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
|
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
||||||
|
"""Write ``table`` with one parquet row group per episode (in episode order).
|
||||||
|
|
||||||
|
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
|
||||||
|
mirroring the recording writer. ``table`` must carry a contiguous
|
||||||
|
``episode_index`` column.
|
||||||
|
"""
|
||||||
|
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
||||||
|
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
||||||
|
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
||||||
|
try:
|
||||||
|
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
|
||||||
|
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
||||||
|
finally:
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
def to_parquet_with_hf_images(
|
def to_parquet_with_hf_images(
|
||||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
||||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
|
||||||
|
|
||||||
Args:
|
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
||||||
df: DataFrame to write to parquet.
|
does not embed external image files like ``Dataset.to_parquet`` does).
|
||||||
path: Path to write the parquet file.
|
``features`` types image columns as ``Image()`` in the parquet schema.
|
||||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
|
||||||
are properly typed as Image() in the parquet schema.
|
|
||||||
"""
|
"""
|
||||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
|
||||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||||
ds.to_parquet(path)
|
ds = embed_images(ds)
|
||||||
|
table = ds.with_format("arrow")[:]
|
||||||
|
if "episode_index" in table.column_names:
|
||||||
|
write_table_one_row_group_per_episode(table, path)
|
||||||
|
else:
|
||||||
|
# No episode boundaries to align row groups to — keep a single write.
|
||||||
|
pq.write_table(table, str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
|
||||||
|
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
|
||||||
|
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||||
|
if "episode_index" in table.column_names:
|
||||||
|
write_table_one_row_group_per_episode(table, path)
|
||||||
|
else:
|
||||||
|
pq.write_table(table, str(path))
|
||||||
|
|
||||||
|
|
||||||
def item_to_torch(item: dict) -> dict:
|
def item_to_torch(item: dict) -> dict:
|
||||||
|
|||||||
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self._requested_root = Path(root) if root else None
|
self._requested_root = Path(root) if root else None
|
||||||
self.reader = None
|
|
||||||
self.set_image_transforms(image_transforms)
|
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
return_uint8=self._return_uint8,
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
if force_cache_sync or not self.reader.try_load():
|
if force_cache_sync or not self.reader.try_load():
|
||||||
@@ -370,6 +369,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.reader.load_and_activate()
|
self.reader.load_and_activate()
|
||||||
return self.reader.hf_dataset
|
return self.reader.hf_dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def absolute_to_relative_idx(self) -> dict[int, int] | None:
|
||||||
|
"""Mapping from absolute frame indices to HF dataset row positions.
|
||||||
|
|
||||||
|
Non-None only for episode-filtered datasets where absolute indices
|
||||||
|
(from metadata) differ from row positions in the loaded HF dataset.
|
||||||
|
"""
|
||||||
|
reader = self._ensure_reader()
|
||||||
|
if reader.hf_dataset is None:
|
||||||
|
reader.load_and_activate()
|
||||||
|
return reader._absolute_to_relative_idx
|
||||||
|
|
||||||
# ── Writer-delegated methods ──────────────────────────────────────
|
# ── Writer-delegated methods ──────────────────────────────────────
|
||||||
|
|
||||||
def add_frame(self, frame: dict) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
@@ -505,15 +516,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
"""Replace the transform applied to visual observations."""
|
"""Replace the transform applied to visual observations."""
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
self._ensure_reader().set_image_transforms(image_transforms)
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
if self.reader is not None:
|
|
||||||
self.reader._image_transforms = image_transforms
|
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
def clear_image_transforms(self) -> None:
|
||||||
"""Remove the transform applied to visual observations."""
|
"""Remove the transform applied to visual observations."""
|
||||||
self.set_image_transforms(None)
|
if self.reader is not None:
|
||||||
|
self.reader.set_image_transforms(None)
|
||||||
|
self.image_transforms = None
|
||||||
|
|
||||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||||
|
|
||||||
@@ -524,7 +534,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
license: str | None = "apache-2.0",
|
license: str | None = "apache-2.0",
|
||||||
tag_version: bool = True,
|
tag_version: bool = True,
|
||||||
push_videos: bool = True,
|
push_videos: bool = True,
|
||||||
private: bool = False,
|
private: bool | None = None,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
upload_large_folder: bool = False,
|
upload_large_folder: bool = False,
|
||||||
**card_kwargs,
|
**card_kwargs,
|
||||||
@@ -543,7 +553,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
tag_version: If ``True``, create a Git tag for the current codebase
|
tag_version: If ``True``, create a Git tag for the current codebase
|
||||||
version.
|
version.
|
||||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||||
private: If ``True``, create a private repository.
|
private: If ``True``, create a private repository. If ``None``
|
||||||
|
(default), defer to the org default on the Hub (only affects orgs).
|
||||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||||
of ``upload_folder`` for very large datasets.
|
of ``upload_folder`` for very large datasets.
|
||||||
|
|||||||
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
|
|||||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||||
*,
|
*,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
|
exclude_images: bool = False,
|
||||||
patterns: Sequence[str] | None = None,
|
patterns: Sequence[str] | None = None,
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||||
|
|
||||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
||||||
formats them for use with a Hugging Face LeRobot Dataset.
|
formats them for use with a Hugging Face LeRobot Dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipeline: The DataProcessorPipeline to apply.
|
pipeline: The DataProcessorPipeline to apply.
|
||||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||||
use_videos: If False, image features are excluded.
|
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
||||||
|
exclude_images: If True, image features are dropped entirely from the output.
|
||||||
patterns: A sequence of regex patterns to filter action and state features.
|
patterns: A sequence of regex patterns to filter action and state features.
|
||||||
Image features are not affected by this filter.
|
Image features are not affected by this filter.
|
||||||
|
|
||||||
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Apply filtering rules.
|
# 2. Apply filtering rules.
|
||||||
if is_image and not use_videos:
|
if is_image and exclude_images:
|
||||||
continue
|
continue
|
||||||
if not is_image and not should_keep(key, compiled_patterns):
|
if not is_image and not should_keep(key, compiled_patterns):
|
||||||
continue
|
continue
|
||||||
|
|||||||
+127
-32
@@ -14,14 +14,36 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EpisodeAwareSampler:
|
class EpisodeAwareSampler:
|
||||||
|
"""Sampler over episode frames that stores only per-episode boundaries.
|
||||||
|
|
||||||
|
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
|
||||||
|
instead of materializing a Python list of every frame index.
|
||||||
|
|
||||||
|
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
|
||||||
|
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
|
||||||
|
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
|
||||||
|
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
|
||||||
|
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
|
||||||
|
resumed epoch, `__len__` still reports the full length.
|
||||||
|
|
||||||
|
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
|
||||||
|
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
|
||||||
|
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
|
||||||
|
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
|
||||||
|
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
|
||||||
|
starts (e.g. on resume or in tests).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_from_indices: list[int],
|
dataset_from_indices: list[int],
|
||||||
@@ -30,57 +52,130 @@ class EpisodeAwareSampler:
|
|||||||
drop_n_first_frames: int = 0,
|
drop_n_first_frames: int = 0,
|
||||||
drop_n_last_frames: int = 0,
|
drop_n_last_frames: int = 0,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
|
seed: int = 0,
|
||||||
|
absolute_to_relative_idx: dict[int, int] | None = None,
|
||||||
):
|
):
|
||||||
"""Sampler that optionally incorporates episode boundary information.
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
dataset_from_indices: Start index of each episode in the dataset.
|
||||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
dataset_to_indices: End index of each episode in the dataset.
|
||||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
episode_indices_to_use: Episode indices to use; None means all.
|
||||||
Assumes that episodes are indexed from 0 to N-1.
|
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
|
||||||
shuffle: Whether to shuffle the indices.
|
shuffle: Whether to shuffle the indices.
|
||||||
|
seed: Seed the permutation is derived from (together with the epoch).
|
||||||
"""
|
"""
|
||||||
if drop_n_first_frames < 0:
|
if drop_n_first_frames < 0:
|
||||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||||
if drop_n_last_frames < 0:
|
if drop_n_last_frames < 0:
|
||||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||||
|
|
||||||
indices = []
|
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||||
for episode_idx, (start_index, end_index) in enumerate(
|
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
if from_indices.shape != to_indices.shape:
|
||||||
):
|
raise ValueError(
|
||||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
f"dataset_from_indices and dataset_to_indices must have the same length, "
|
||||||
ep_length = end_index - start_index
|
f"got {len(from_indices)} and {len(to_indices)}"
|
||||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
)
|
||||||
logger.warning(
|
|
||||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
|
||||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
|
||||||
episode_idx,
|
|
||||||
ep_length,
|
|
||||||
drop_n_first_frames,
|
|
||||||
drop_n_last_frames,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
|
||||||
|
|
||||||
if not indices:
|
used = np.ones(len(from_indices), dtype=bool)
|
||||||
|
if episode_indices_to_use is not None:
|
||||||
|
used = np.zeros(len(from_indices), dtype=bool)
|
||||||
|
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
|
||||||
|
|
||||||
|
starts = from_indices + drop_n_first_frames
|
||||||
|
lengths = to_indices - drop_n_last_frames - starts
|
||||||
|
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
|
||||||
|
logger.warning(
|
||||||
|
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||||
|
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||||
|
episode_idx,
|
||||||
|
to_indices[episode_idx] - from_indices[episode_idx],
|
||||||
|
drop_n_first_frames,
|
||||||
|
drop_n_last_frames,
|
||||||
|
)
|
||||||
|
used &= lengths > 0
|
||||||
|
if not used.any():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||||
"All episodes were either filtered out or had too few frames."
|
"All episodes were either filtered out or had too few frames."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.indices = indices
|
self._starts = starts[used]
|
||||||
|
self._cum_lengths = np.cumsum(lengths[used])
|
||||||
|
self._num_frames = int(self._cum_lengths[-1])
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
self.seed = seed
|
||||||
|
self._epoch = 0
|
||||||
|
self._start_index = 0
|
||||||
|
self._absolute_to_relative = absolute_to_relative_idx
|
||||||
|
|
||||||
|
@property
|
||||||
|
def indices(self) -> list[int]:
|
||||||
|
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||||
|
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int) -> None:
|
||||||
|
self._epoch = epoch
|
||||||
|
|
||||||
|
def state_dict(self) -> dict:
|
||||||
|
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict) -> None:
|
||||||
|
self._epoch = state["epoch"]
|
||||||
|
self._start_index = state["start_index"]
|
||||||
|
|
||||||
|
def _epoch_generator(self, epoch: int) -> torch.Generator:
|
||||||
|
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
|
||||||
|
# and reproduces identically on every rank without touching the global RNG.
|
||||||
|
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
|
||||||
|
return torch.Generator().manual_seed(epoch_seed)
|
||||||
|
|
||||||
|
def _frame_index(self, position: int) -> int:
|
||||||
|
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||||
|
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||||
|
absolute_idx = int(self._starts[episode]) + position_in_episode
|
||||||
|
if self._absolute_to_relative is not None:
|
||||||
|
return self._absolute_to_relative[absolute_idx]
|
||||||
|
return absolute_idx
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[int]:
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||||
|
epoch, start = self._epoch, self._start_index
|
||||||
|
self._epoch += 1
|
||||||
|
self._start_index = 0
|
||||||
|
return self._iter_epoch(epoch, start)
|
||||||
|
|
||||||
|
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
for i in torch.randperm(len(self.indices)):
|
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
|
||||||
yield self.indices[i]
|
for k in range(start, self._num_frames):
|
||||||
|
yield self._frame_index(int(order[k]))
|
||||||
else:
|
else:
|
||||||
for i in self.indices:
|
for k in range(start, self._num_frames):
|
||||||
yield i
|
yield self._frame_index(k)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.indices)
|
return self._num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||||
|
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||||
|
|
||||||
|
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||||
|
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||||
|
per epoch (`even_batches` padding included). The start index provably stays below
|
||||||
|
`num_frames`; the `min` is defensive.
|
||||||
|
|
||||||
|
Assumptions (resume is only sample-exact when they hold):
|
||||||
|
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||||
|
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||||
|
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
|
||||||
|
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||||
|
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||||
|
the boundary is off.
|
||||||
|
"""
|
||||||
|
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||||
|
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||||
|
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
|
||||||
|
return {"epoch": epoch, "start_index": start_index}
|
||||||
|
|||||||
@@ -481,8 +481,10 @@ def reencode_video(
|
|||||||
encoder_threads: int | None = None,
|
encoder_threads: int | None = None,
|
||||||
log_level: int | None = av.logging.WARNING,
|
log_level: int | None = av.logging.WARNING,
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
|
start_time_s: float | None = None,
|
||||||
|
end_time_s: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Re-encode a video file using the given encoder configuration.
|
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_video_path: Existing video file to read.
|
input_video_path: Existing video file to read.
|
||||||
@@ -491,10 +493,17 @@ def reencode_video(
|
|||||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||||
|
start_time_s: When set, trim the output to start at this timestamp (seconds).
|
||||||
|
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||||
|
|
||||||
|
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
|
||||||
|
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
|
||||||
|
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
|
||||||
|
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
|
||||||
|
|
||||||
output_video_path = Path(output_video_path)
|
output_video_path = Path(output_video_path)
|
||||||
|
|
||||||
if output_video_path.exists() and not overwrite:
|
if output_video_path.exists() and not overwrite:
|
||||||
@@ -526,6 +535,10 @@ def reencode_video(
|
|||||||
width = int(in_stream.width)
|
width = int(in_stream.width)
|
||||||
height = int(in_stream.height)
|
height = int(in_stream.height)
|
||||||
|
|
||||||
|
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
|
||||||
|
if start_time_s is not None:
|
||||||
|
src.seek(int(start_time_s * av.time_base), backward=True)
|
||||||
|
|
||||||
with av.open(
|
with av.open(
|
||||||
tmp_output_video_path,
|
tmp_output_video_path,
|
||||||
mode="w",
|
mode="w",
|
||||||
@@ -539,7 +552,14 @@ def reencode_video(
|
|||||||
out_stream.height = height
|
out_stream.height = height
|
||||||
|
|
||||||
for frame in src.decode(in_stream):
|
for frame in src.decode(in_stream):
|
||||||
|
frame_time_s = frame.time
|
||||||
|
if start_time_s is not None and frame_time_s < start_time_s:
|
||||||
|
continue
|
||||||
|
if end_time_s is not None and frame_time_s >= end_time_s:
|
||||||
|
break
|
||||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||||
|
if start_time_s is not None:
|
||||||
|
frame.pts = None # reset timestamps so the trimmed output starts at t=0
|
||||||
packet = out_stream.encode(frame)
|
packet = out_stream.encode(frame)
|
||||||
if packet:
|
if packet:
|
||||||
dst.mux(packet)
|
dst.mux(packet)
|
||||||
|
|||||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
if "camera_obs" in observations:
|
if "camera_obs" in observations:
|
||||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||||
|
|
||||||
|
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||||
|
# so env plugins can expose extra observation keys via get_env_processors().
|
||||||
|
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||||
|
for key, value in observations.items():
|
||||||
|
if key in _handled:
|
||||||
|
continue
|
||||||
|
target = f"{OBS_STR}.{key}"
|
||||||
|
if target in return_observations:
|
||||||
|
continue
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
val = torch.from_numpy(value).float()
|
||||||
|
if val.dim() == 1:
|
||||||
|
val = val.unsqueeze(0)
|
||||||
|
return_observations[target] = val
|
||||||
|
elif isinstance(value, Tensor):
|
||||||
|
val = value.float()
|
||||||
|
if val.dim() == 1:
|
||||||
|
val = val.unsqueeze(0)
|
||||||
|
return_observations[target] = val
|
||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from .optimizers import (
|
|||||||
SGDConfig as SGDConfig,
|
SGDConfig as SGDConfig,
|
||||||
XVLAAdamWConfig as XVLAAdamWConfig,
|
XVLAAdamWConfig as XVLAAdamWConfig,
|
||||||
load_optimizer_state,
|
load_optimizer_state,
|
||||||
|
load_optimizer_state_dict,
|
||||||
save_optimizer_state,
|
save_optimizer_state,
|
||||||
)
|
)
|
||||||
from .schedulers import (
|
from .schedulers import (
|
||||||
@@ -50,6 +51,7 @@ __all__ = [
|
|||||||
"VQBeTSchedulerConfig",
|
"VQBeTSchedulerConfig",
|
||||||
# State management
|
# State management
|
||||||
"load_optimizer_state",
|
"load_optimizer_state",
|
||||||
|
"load_optimizer_state_dict",
|
||||||
"load_scheduler_state",
|
"load_scheduler_state",
|
||||||
"save_optimizer_state",
|
"save_optimizer_state",
|
||||||
"save_scheduler_state",
|
"save_scheduler_state",
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
|
|||||||
OPTIMIZER_PARAM_GROUPS,
|
OPTIMIZER_PARAM_GROUPS,
|
||||||
OPTIMIZER_STATE,
|
OPTIMIZER_STATE,
|
||||||
)
|
)
|
||||||
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
|
from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json
|
||||||
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
# Type alias for parameters accepted by optimizer build() methods.
|
# Type alias for parameters accepted by optimizer build() methods.
|
||||||
@@ -281,28 +281,37 @@ class MultiAdamConfig(OptimizerConfig):
|
|||||||
|
|
||||||
|
|
||||||
def save_optimizer_state(
|
def save_optimizer_state(
|
||||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],
|
||||||
|
save_dir: Path,
|
||||||
|
optim_state_dict: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save optimizer state to disk.
|
"""Save optimizer state to disk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||||
save_dir: Directory to save the optimizer state.
|
save_dir: Directory to save the optimizer state.
|
||||||
|
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
|
||||||
|
be gathered across ranks first). If provided, it is saved directly instead of calling
|
||||||
|
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if isinstance(optimizer, dict):
|
if isinstance(optimizer, dict):
|
||||||
# Handle dictionary of optimizers
|
# Handle dictionary of optimizers
|
||||||
|
if optim_state_dict is not None:
|
||||||
|
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
|
||||||
for name, opt in optimizer.items():
|
for name, opt in optimizer.items():
|
||||||
optimizer_dir = save_dir / name
|
optimizer_dir = save_dir / name
|
||||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||||
_save_single_optimizer_state(opt, optimizer_dir)
|
_save_single_optimizer_state(opt, optimizer_dir)
|
||||||
else:
|
else:
|
||||||
# Handle single optimizer
|
# Handle single optimizer
|
||||||
_save_single_optimizer_state(optimizer, save_dir)
|
_save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
||||||
|
|
||||||
|
|
||||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
def _save_single_optimizer_state(
|
||||||
|
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
|
||||||
|
) -> None:
|
||||||
"""Save a single optimizer's state to disk."""
|
"""Save a single optimizer's state to disk."""
|
||||||
state = optimizer.state_dict()
|
state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict()
|
||||||
param_groups = state.pop("param_groups")
|
param_groups = state.pop("param_groups")
|
||||||
flat_state = flatten_dict(state)
|
flat_state = flatten_dict(state)
|
||||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||||
@@ -356,3 +365,19 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
|||||||
|
|
||||||
optimizer.load_state_dict(loaded_state_dict)
|
optimizer.load_state_dict(loaded_state_dict)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_optimizer_state_dict(save_dir: Path) -> dict:
|
||||||
|
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
|
||||||
|
|
||||||
|
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
|
||||||
|
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
|
||||||
|
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
|
||||||
|
before being loaded into the (sharded) optimizer.
|
||||||
|
"""
|
||||||
|
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||||
|
state = unflatten_dict(flat_state)
|
||||||
|
return {
|
||||||
|
"state": state.get("state", {}),
|
||||||
|
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
|||||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||||
|
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||||
@@ -43,6 +44,7 @@ __all__ = [
|
|||||||
"EO1Config",
|
"EO1Config",
|
||||||
"GaussianActorConfig",
|
"GaussianActorConfig",
|
||||||
"GrootConfig",
|
"GrootConfig",
|
||||||
|
"MolmoAct2Config",
|
||||||
"MultiTaskDiTConfig",
|
"MultiTaskDiTConfig",
|
||||||
"PI0Config",
|
"PI0Config",
|
||||||
"PI0FastConfig",
|
"PI0FastConfig",
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae and log_sigma_x2_hat is not None:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
|
|||||||
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations.
|
||||||
# stack n latest observations from the queue
|
|
||||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
|
||||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
|
||||||
|
|
||||||
|
Supports two modes:
|
||||||
|
- Online (queues populated via select_action): stacks observations from internal queues.
|
||||||
|
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
|
||||||
|
"""
|
||||||
|
queues_populated = any(len(q) > 0 for q in self._queues.values())
|
||||||
|
if queues_populated:
|
||||||
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
|
else:
|
||||||
|
batch = dict(batch)
|
||||||
|
if self.config.image_features:
|
||||||
|
for key in self.config.image_features:
|
||||||
|
if batch[key].ndim == 4:
|
||||||
|
batch[key] = batch[key].unsqueeze(1)
|
||||||
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
|||||||
from .eo1.configuration_eo1 import EO1Config
|
from .eo1.configuration_eo1 import EO1Config
|
||||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||||
from .groot.configuration_groot import GrootConfig
|
from .groot.configuration_groot import GrootConfig
|
||||||
|
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config
|
from .pi0.configuration_pi0 import PI0Config
|
||||||
from .pi05.configuration_pi05 import PI05Config
|
from .pi05.configuration_pi05 import PI05Config
|
||||||
@@ -56,6 +57,7 @@ from .pretrained import PreTrainedPolicy
|
|||||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from .utils import validate_visual_features_consistency
|
from .utils import validate_visual_features_consistency
|
||||||
|
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from .wall_x.configuration_wall_x import WallXConfig
|
from .wall_x.configuration_wall_x import WallXConfig
|
||||||
from .xvla.configuration_xvla import XVLAConfig
|
from .xvla.configuration_xvla import XVLAConfig
|
||||||
@@ -88,7 +90,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||||
|
"molmoact2".
|
||||||
Returns:
|
Returns:
|
||||||
The policy class corresponding to the given name.
|
The policy class corresponding to the given name.
|
||||||
|
|
||||||
@@ -151,6 +154,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .eo1.modeling_eo1 import EO1Policy
|
from .eo1.modeling_eo1 import EO1Policy
|
||||||
|
|
||||||
return EO1Policy
|
return EO1Policy
|
||||||
|
elif name == "molmoact2":
|
||||||
|
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||||
|
|
||||||
|
return MolmoAct2Policy
|
||||||
|
elif name == "vla_jepa":
|
||||||
|
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||||
|
|
||||||
|
return VLAJEPAPolicy
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return _get_policy_cls_from_policy_name(name=name)
|
return _get_policy_cls_from_policy_name(name=name)
|
||||||
@@ -168,7 +179,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
Args:
|
Args:
|
||||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||||
"smolvla", "wall_x".
|
"smolvla", "wall_x", "molmoact2".
|
||||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -203,6 +214,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return WallXConfig(**kwargs)
|
return WallXConfig(**kwargs)
|
||||||
elif policy_type == "eo1":
|
elif policy_type == "eo1":
|
||||||
return EO1Config(**kwargs)
|
return EO1Config(**kwargs)
|
||||||
|
elif policy_type == "molmoact2":
|
||||||
|
return MolmoAct2Config(**kwargs)
|
||||||
|
elif policy_type == "vla_jepa":
|
||||||
|
return VLAJEPAConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||||
@@ -231,11 +246,13 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
|||||||
preprocessor_overrides: dict[str, Any] | None
|
preprocessor_overrides: dict[str, Any] | None
|
||||||
postprocessor_overrides: dict[str, Any] | None
|
postprocessor_overrides: dict[str, Any] | None
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||||
|
dataset_meta: Any | None
|
||||||
|
|
||||||
|
|
||||||
def make_pre_post_processors(
|
def make_pre_post_processors(
|
||||||
policy_cfg: PreTrainedConfig,
|
policy_cfg: PreTrainedConfig,
|
||||||
pretrained_path: str | None = None,
|
pretrained_path: str | None = None,
|
||||||
|
pretrained_revision: str | None = None,
|
||||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
@@ -293,6 +310,7 @@ def make_pre_post_processors(
|
|||||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||||
to_transition=batch_to_transition,
|
to_transition=batch_to_transition,
|
||||||
to_output=transition_to_batch,
|
to_output=transition_to_batch,
|
||||||
|
revision=pretrained_revision,
|
||||||
)
|
)
|
||||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path=pretrained_path,
|
pretrained_model_name_or_path=pretrained_path,
|
||||||
@@ -302,6 +320,7 @@ def make_pre_post_processors(
|
|||||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||||
to_transition=policy_action_to_transition,
|
to_transition=policy_action_to_transition,
|
||||||
to_output=transition_to_policy_action,
|
to_output=transition_to_policy_action,
|
||||||
|
revision=pretrained_revision,
|
||||||
)
|
)
|
||||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
@@ -406,6 +425,7 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, EO1Config):
|
elif isinstance(policy_cfg, EO1Config):
|
||||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
@@ -414,6 +434,23 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||||
|
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_molmoact2_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
dataset_meta=kwargs.get("dataset_meta"),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||||
|
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_vla_jepa_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
processors = _make_processors_from_policy_config(
|
processors = _make_processors_from_policy_config(
|
||||||
@@ -499,6 +536,10 @@ def make_policy(
|
|||||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||||
if action_names is not None:
|
if action_names is not None:
|
||||||
cfg.action_feature_names = list(action_names)
|
cfg.action_feature_names = list(action_names)
|
||||||
|
if ds_meta is not None:
|
||||||
|
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
|
||||||
|
if callable(set_dataset_feature_metadata):
|
||||||
|
set_dataset_feature_metadata(ds_meta.features)
|
||||||
|
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
|
|
||||||
@@ -519,6 +560,7 @@ def make_policy(
|
|||||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||||
# hyperparameters that we want to vary).
|
# hyperparameters that we want to vary).
|
||||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||||
|
kwargs["revision"] = cfg.pretrained_revision
|
||||||
policy = policy_cls.from_pretrained(**kwargs)
|
policy = policy_cls.from_pretrained(**kwargs)
|
||||||
elif cfg.pretrained_path and cfg.use_peft:
|
elif cfg.pretrained_path and cfg.use_peft:
|
||||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
|
|||||||
"SiglipEncoderLayer",
|
"SiglipEncoderLayer",
|
||||||
]
|
]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_flash_attn = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
|
|||||||
"videos_kwargs",
|
"videos_kwargs",
|
||||||
"text_kwargs",
|
"text_kwargs",
|
||||||
]
|
]
|
||||||
image_processor_class = "AutoImageProcessor"
|
|
||||||
tokenizer_class = "AutoTokenizer"
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
|
|||||||
"Vendor files are copied during model creation. Create the policy/model first, "
|
"Vendor files are copied during model creation. Create the policy/model first, "
|
||||||
"or call ensure_eagle_cache_ready() before building processors."
|
"or call ensure_eagle_cache_ready() before building processors."
|
||||||
)
|
)
|
||||||
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
|
proc = AutoProcessor.from_pretrained(
|
||||||
|
str(cache_dir),
|
||||||
|
trust_remote_code=True,
|
||||||
|
fix_mistral_regex=False,
|
||||||
|
)
|
||||||
proc.tokenizer.padding_side = "left"
|
proc.tokenizer.padding_side = "left"
|
||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
../../../../docs/source/policy_molmoact2_README.md
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_molmoact2 import MolmoAct2Config
|
||||||
|
from .modeling_molmoact2 import MolmoAct2Policy
|
||||||
|
from .processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
|
||||||
@@ -0,0 +1,519 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from contextlib import suppress
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||||
|
from lerobot.optim import (
|
||||||
|
AdamWConfig,
|
||||||
|
CosineDecayWithWarmupSchedulerConfig,
|
||||||
|
LRSchedulerConfig,
|
||||||
|
OptimizerConfig,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
|
from ..rtc.configuration_rtc import RTCConfig
|
||||||
|
|
||||||
|
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
|
||||||
|
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
|
||||||
|
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
|
||||||
|
MOLMOACT2_TASK_TOKEN_BUDGET = 32
|
||||||
|
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
|
||||||
|
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
|
||||||
|
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
|
||||||
|
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
|
||||||
|
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
|
||||||
|
|
||||||
|
|
||||||
|
def _hf_token() -> str | None:
|
||||||
|
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_checkpoint_location(
|
||||||
|
checkpoint_path: str,
|
||||||
|
*,
|
||||||
|
revision: str | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
) -> str:
|
||||||
|
checkpoint_path = str(checkpoint_path or "").strip()
|
||||||
|
if not checkpoint_path:
|
||||||
|
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
|
||||||
|
local_path = Path(checkpoint_path).expanduser()
|
||||||
|
if local_path.exists():
|
||||||
|
return str(local_path)
|
||||||
|
return snapshot_download(
|
||||||
|
repo_id=checkpoint_path,
|
||||||
|
repo_type="model",
|
||||||
|
revision=revision,
|
||||||
|
force_download=force_download,
|
||||||
|
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||||
|
token=_hf_token(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_hf_norm_metadata_for_tag(
|
||||||
|
checkpoint_path: str,
|
||||||
|
*,
|
||||||
|
revision: str | None,
|
||||||
|
force_download: bool,
|
||||||
|
norm_tag: str | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
norm_tag = str(norm_tag or "").strip()
|
||||||
|
if not norm_tag:
|
||||||
|
return {}
|
||||||
|
checkpoint_location = Path(
|
||||||
|
_resolve_checkpoint_location(
|
||||||
|
checkpoint_path,
|
||||||
|
revision=revision,
|
||||||
|
force_download=force_download,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
norm_stats_filename = "norm_stats.json"
|
||||||
|
config_path = checkpoint_location / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
with suppress(OSError, json.JSONDecodeError):
|
||||||
|
norm_stats_filename = str(
|
||||||
|
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
|
||||||
|
)
|
||||||
|
stats_path = checkpoint_location / norm_stats_filename
|
||||||
|
if not stats_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
|
||||||
|
)
|
||||||
|
payload = json.loads(stats_path.read_text())
|
||||||
|
metadata_by_tag = payload.get("metadata_by_tag")
|
||||||
|
if not isinstance(metadata_by_tag, dict):
|
||||||
|
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
|
||||||
|
metadata = metadata_by_tag.get(norm_tag)
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
available = sorted(str(tag) for tag in metadata_by_tag)
|
||||||
|
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
|
||||||
|
@dataclass
|
||||||
|
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
|
||||||
|
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
|
||||||
|
|
||||||
|
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
|
||||||
|
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
|
||||||
|
training steps"; build() is the first point where num_training_steps is known.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_decay_steps: int | None
|
||||||
|
|
||||||
|
def build(self, optimizer, num_training_steps: int):
|
||||||
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.peak_lr,
|
||||||
|
decay_lr=self.decay_lr,
|
||||||
|
num_warmup_steps=self.num_warmup_steps,
|
||||||
|
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
|
||||||
|
).build(optimizer, num_training_steps=num_training_steps)
|
||||||
|
|
||||||
|
|
||||||
|
def _round_up(value: int, multiple: int) -> int:
|
||||||
|
return int(math.ceil(value / multiple) * multiple)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_molmoact2_max_sequence_length(
|
||||||
|
*,
|
||||||
|
num_images: int,
|
||||||
|
state_dim: int,
|
||||||
|
action_dim: int,
|
||||||
|
action_horizon: int,
|
||||||
|
include_discrete_action: bool,
|
||||||
|
) -> int:
|
||||||
|
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
|
||||||
|
if num_images < 1:
|
||||||
|
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||||
|
if state_dim < 0:
|
||||||
|
state_dim = 0
|
||||||
|
if action_dim < 1:
|
||||||
|
action_dim = 1
|
||||||
|
if action_horizon < 1:
|
||||||
|
action_horizon = 1
|
||||||
|
|
||||||
|
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
|
||||||
|
prompt_tokens = (
|
||||||
|
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
|
||||||
|
+ MOLMOACT2_TASK_TOKEN_BUDGET
|
||||||
|
+ state_dim
|
||||||
|
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
|
||||||
|
)
|
||||||
|
action_tokens = 0
|
||||||
|
if include_discrete_action:
|
||||||
|
action_tokens_per_step = max(
|
||||||
|
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
|
||||||
|
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
|
||||||
|
)
|
||||||
|
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
|
||||||
|
|
||||||
|
return _round_up(
|
||||||
|
image_tokens + prompt_tokens + action_tokens,
|
||||||
|
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("molmoact2")
|
||||||
|
@dataclass
|
||||||
|
class MolmoAct2Config(PreTrainedConfig):
|
||||||
|
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
|
||||||
|
|
||||||
|
checkpoint_path: str = "allenai/MolmoAct2"
|
||||||
|
checkpoint_revision: str | None = None
|
||||||
|
checkpoint_force_download: bool = False
|
||||||
|
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 30
|
||||||
|
n_action_steps: int = 30
|
||||||
|
|
||||||
|
action_mode: str = "both"
|
||||||
|
inference_action_mode: str | None = None
|
||||||
|
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
|
||||||
|
discrete_generation_max_steps: int | None = None
|
||||||
|
norm_tag: str | None = None
|
||||||
|
|
||||||
|
setup_type: str = ""
|
||||||
|
control_mode: str = ""
|
||||||
|
image_keys: list[str] = field(default_factory=list)
|
||||||
|
normalize_language: bool = True
|
||||||
|
add_setup_tokens: bool = True
|
||||||
|
add_control_tokens: bool = True
|
||||||
|
normalize_gripper: bool = False
|
||||||
|
num_state_tokens: int = 256
|
||||||
|
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
|
||||||
|
# image/prompt/state/action token layout. Override only for unusual long prompts.
|
||||||
|
max_sequence_length: int | None = None
|
||||||
|
|
||||||
|
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
|
||||||
|
expected_max_action_dim: int = 32
|
||||||
|
|
||||||
|
# Flow-matching training knobs copied from the original MolmoAct2 training path.
|
||||||
|
num_flow_timesteps: int = 8
|
||||||
|
flow_matching_cutoff: float = 1.0
|
||||||
|
flow_matching_time_offset: float = 0.001
|
||||||
|
flow_matching_time_scale: float = 0.999
|
||||||
|
flow_matching_beta_alpha: float = 1.0
|
||||||
|
flow_matching_beta_beta: float = 1.5
|
||||||
|
num_inference_steps: int | None = None
|
||||||
|
mask_action_dim_padding: bool = True
|
||||||
|
enable_inference_cuda_graph: bool = True
|
||||||
|
# MolmoAct2-local eval option. When enabled, stochastic continuous action
|
||||||
|
# generation uses a rollout-local generator derived from eval_seed.
|
||||||
|
per_episode_seed: bool = False
|
||||||
|
eval_seed: int | None = None
|
||||||
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
|
# Default is full finetuning with gradients from the action expert flowing into the VLM.
|
||||||
|
enable_lora_vlm: bool = False
|
||||||
|
lora_rank: int = 64
|
||||||
|
lora_alpha: int = 16
|
||||||
|
lora_dropout: float = 0.05
|
||||||
|
lora_bias: str = "none"
|
||||||
|
enable_lora_action_expert: bool = False
|
||||||
|
enable_knowledge_insulation: bool = False
|
||||||
|
freeze_embedding: bool = True
|
||||||
|
train_action_expert_only: bool = False
|
||||||
|
gradient_checkpointing: bool = False
|
||||||
|
|
||||||
|
model_dtype: str = "bfloat16"
|
||||||
|
softmax_auxiliary_loss: bool = True
|
||||||
|
softmax_auxiliary_loss_scale: float = 1e-4
|
||||||
|
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
|
||||||
|
|
||||||
|
optimizer_lr: float = 1e-5
|
||||||
|
optimizer_vit_lr: float = 5e-6
|
||||||
|
optimizer_connector_lr: float = 5e-6
|
||||||
|
optimizer_action_expert_lr: float = 5e-5
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
optimizer_eps: float = 1e-6
|
||||||
|
optimizer_weight_decay: float = 0.0
|
||||||
|
optimizer_grad_clip_norm: float = 1.0
|
||||||
|
|
||||||
|
scheduler_warmup_steps: int = 200
|
||||||
|
scheduler_decay_steps: int | None = None
|
||||||
|
scheduler_decay_lr: float = 1e-6
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.QUANTILES,
|
||||||
|
"ACTION": NormalizationMode.QUANTILES,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__post_init__()
|
||||||
|
if self.action_mode not in {"continuous", "discrete", "both"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported action_mode={self.action_mode!r}. "
|
||||||
|
"Expected one of {'continuous', 'discrete', 'both'}."
|
||||||
|
)
|
||||||
|
if self.inference_action_mode not in {None, "continuous", "discrete"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
|
||||||
|
"Expected one of {None, 'continuous', 'discrete'}."
|
||||||
|
)
|
||||||
|
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
|
||||||
|
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
|
||||||
|
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
|
||||||
|
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
|
||||||
|
if self.train_action_expert_only and self.action_mode != "continuous":
|
||||||
|
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
|
||||||
|
if self.train_action_expert_only and self.enable_lora_vlm:
|
||||||
|
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
|
||||||
|
if self.enable_lora_action_expert and not self.enable_lora_vlm:
|
||||||
|
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
|
||||||
|
if self.chunk_size < 1:
|
||||||
|
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
|
||||||
|
if self.n_action_steps < 1:
|
||||||
|
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
|
||||||
|
)
|
||||||
|
if self.expected_max_action_dim != 32:
|
||||||
|
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
|
||||||
|
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
|
||||||
|
)
|
||||||
|
if self.lora_rank < 1:
|
||||||
|
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
|
||||||
|
if self.lora_alpha < 1:
|
||||||
|
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
|
||||||
|
if not 0 <= self.lora_dropout <= 1:
|
||||||
|
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
|
||||||
|
if self.lora_bias not in {"none", "all", "lora_only"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
|
||||||
|
)
|
||||||
|
if self.discrete_loss_token_weighting not in {
|
||||||
|
"none",
|
||||||
|
"token",
|
||||||
|
"root_tokens",
|
||||||
|
"root_subsegments",
|
||||||
|
"root_subsegments_root_tokens",
|
||||||
|
}:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
|
||||||
|
)
|
||||||
|
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
|
||||||
|
)
|
||||||
|
if self.max_sequence_length is not None and self.max_sequence_length < 1:
|
||||||
|
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
|
||||||
|
|
||||||
|
def inferred_max_sequence_length(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
num_images: int | None = None,
|
||||||
|
state_dim: int | None = None,
|
||||||
|
action_dim: int | None = None,
|
||||||
|
action_horizon: int | None = None,
|
||||||
|
include_discrete_action: bool | None = None,
|
||||||
|
) -> int:
|
||||||
|
if self.max_sequence_length is not None:
|
||||||
|
return int(self.max_sequence_length)
|
||||||
|
|
||||||
|
if num_images is None:
|
||||||
|
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
|
||||||
|
if state_dim is None:
|
||||||
|
state_feature = self.robot_state_feature
|
||||||
|
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
|
||||||
|
if action_dim is None:
|
||||||
|
action_feature = self.action_feature
|
||||||
|
action_dim = (
|
||||||
|
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
|
||||||
|
)
|
||||||
|
if action_horizon is None:
|
||||||
|
action_horizon = self.chunk_size
|
||||||
|
if include_discrete_action is None:
|
||||||
|
include_discrete_action = self.action_mode in {"discrete", "both"}
|
||||||
|
|
||||||
|
return infer_molmoact2_max_sequence_length(
|
||||||
|
num_images=int(num_images),
|
||||||
|
state_dim=int(state_dim),
|
||||||
|
action_dim=int(action_dim),
|
||||||
|
action_horizon=int(action_horizon),
|
||||||
|
include_discrete_action=bool(include_discrete_action),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||||
|
return AdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||||
|
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.optimizer_lr,
|
||||||
|
decay_lr=self.scheduler_decay_lr,
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
num_decay_steps=self.scheduler_decay_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
|
||||||
|
self.dataset_feature_names = {}
|
||||||
|
for key in (ACTION, OBS_STATE):
|
||||||
|
feature = features.get(key) if isinstance(features, dict) else None
|
||||||
|
if isinstance(feature, dict) and feature.get("names") is not None:
|
||||||
|
self.dataset_feature_names[key] = feature["names"]
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
"""Validate and set up MolmoAct2 input and output features."""
|
||||||
|
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||||
|
if not image_features:
|
||||||
|
raise ValueError(
|
||||||
|
"MolmoAct2 policy requires at least one visual input feature. "
|
||||||
|
"No features of type FeatureType.VISUAL found in input_features."
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_STATE not in self.input_features:
|
||||||
|
state_feature = PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(0,),
|
||||||
|
)
|
||||||
|
self.input_features[OBS_STATE] = state_feature
|
||||||
|
|
||||||
|
if ACTION not in self.output_features:
|
||||||
|
action_feature = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(self.expected_max_action_dim,),
|
||||||
|
)
|
||||||
|
self.output_features[ACTION] = action_feature
|
||||||
|
|
||||||
|
def apply_norm_tag_metadata(self) -> None:
|
||||||
|
if not str(self.norm_tag or "").strip():
|
||||||
|
return
|
||||||
|
metadata = _load_hf_norm_metadata_for_tag(
|
||||||
|
self.checkpoint_path,
|
||||||
|
revision=self.checkpoint_revision,
|
||||||
|
force_download=bool(self.checkpoint_force_download),
|
||||||
|
norm_tag=self.norm_tag,
|
||||||
|
)
|
||||||
|
if metadata.get("action_horizon") is not None:
|
||||||
|
self.chunk_size = int(metadata["action_horizon"])
|
||||||
|
if metadata.get("n_action_steps") is not None:
|
||||||
|
self.n_action_steps = int(metadata["n_action_steps"])
|
||||||
|
if not self.setup_type and metadata.get("setup_type") is not None:
|
||||||
|
self.setup_type = str(metadata["setup_type"])
|
||||||
|
if not self.control_mode and metadata.get("control_mode") is not None:
|
||||||
|
self.control_mode = str(metadata["control_mode"])
|
||||||
|
|
||||||
|
def saved_policy_action_mode(self) -> str | None:
|
||||||
|
pretrained_path = getattr(self, "pretrained_path", None)
|
||||||
|
if pretrained_path is None:
|
||||||
|
return None
|
||||||
|
config_path = Path(pretrained_path) / "config.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
mode = json.loads(config_path.read_text()).get("action_mode")
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return None
|
||||||
|
if mode in {"continuous", "discrete", "both"}:
|
||||||
|
return str(mode)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
|
||||||
|
return saved_policy_action_mode or self.action_mode
|
||||||
|
|
||||||
|
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
|
||||||
|
requested_mode = self.inference_action_mode
|
||||||
|
if requested_mode is None:
|
||||||
|
return
|
||||||
|
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||||
|
if requested_mode == "continuous" and training_mode == "discrete":
|
||||||
|
raise ValueError(
|
||||||
|
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
|
||||||
|
"continuous inference."
|
||||||
|
)
|
||||||
|
if requested_mode == "discrete" and training_mode == "continuous":
|
||||||
|
raise ValueError(
|
||||||
|
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
|
||||||
|
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_checkpoint_action_mode(
|
||||||
|
self,
|
||||||
|
checkpoint_action_mode: str,
|
||||||
|
*,
|
||||||
|
has_action_expert: bool,
|
||||||
|
) -> None:
|
||||||
|
if self.action_mode == "both" and checkpoint_action_mode != "both":
|
||||||
|
raise ValueError(
|
||||||
|
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
|
||||||
|
)
|
||||||
|
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
|
||||||
|
raise ValueError(
|
||||||
|
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
|
||||||
|
f"got {checkpoint_action_mode!r}."
|
||||||
|
)
|
||||||
|
if self.action_mode in {"continuous", "both"} and not has_action_expert:
|
||||||
|
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
|
||||||
|
|
||||||
|
def resolve_inference_action_mode(
|
||||||
|
self,
|
||||||
|
requested_mode: str | None,
|
||||||
|
saved_policy_action_mode: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
training_mode = self.training_action_mode(saved_policy_action_mode)
|
||||||
|
if requested_mode is None:
|
||||||
|
requested_mode = self.inference_action_mode
|
||||||
|
if requested_mode is None:
|
||||||
|
raise ValueError(
|
||||||
|
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
|
||||||
|
"to either 'continuous' or 'discrete'."
|
||||||
|
)
|
||||||
|
if requested_mode not in {"continuous", "discrete"}:
|
||||||
|
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
|
||||||
|
if requested_mode == "continuous" and training_mode == "discrete":
|
||||||
|
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
|
||||||
|
if requested_mode == "discrete" and training_mode == "continuous":
|
||||||
|
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
|
||||||
|
return requested_mode
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tokenizers import ByteLevelBPETokenizer
|
||||||
|
from tokenizers.trainers import BpeTrainer
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
from transformers.processing_utils import ProcessorMixin
|
||||||
|
|
||||||
|
|
||||||
|
def _hf_token() -> str | None:
|
||||||
|
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tokenizer_location(
|
||||||
|
tokenizer_path: str,
|
||||||
|
*,
|
||||||
|
revision: str | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
) -> str:
|
||||||
|
local_path = Path(str(tokenizer_path)).expanduser()
|
||||||
|
if local_path.exists():
|
||||||
|
return str(local_path)
|
||||||
|
return snapshot_download(
|
||||||
|
repo_id=str(tokenizer_path),
|
||||||
|
repo_type="model",
|
||||||
|
revision=revision,
|
||||||
|
force_download=force_download,
|
||||||
|
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
|
||||||
|
token=_hf_token(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalActionProcessor(ProcessorMixin):
|
||||||
|
attributes: ClassVar[list[str]] = ["tokenizer"]
|
||||||
|
tokenizer_class: str = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizerFast,
|
||||||
|
scale: float = 10,
|
||||||
|
vocab_size: int = 1024,
|
||||||
|
min_token: int = 0,
|
||||||
|
*,
|
||||||
|
action_dim: int | None = None,
|
||||||
|
time_horizon: int | None = None,
|
||||||
|
):
|
||||||
|
self.scale = scale
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.min_token = min_token
|
||||||
|
|
||||||
|
# Action horizon and dimension needed during decoding. These can be specified
|
||||||
|
# in three ways (in order of priority):
|
||||||
|
# 1. passed in as kwargs to decode()
|
||||||
|
# 2. in the constructor
|
||||||
|
# 3. cached from the last time decode() was called
|
||||||
|
self.time_horizon = time_horizon
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.called_time_horizon = time_horizon
|
||||||
|
self.called_action_dim = action_dim
|
||||||
|
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
self.bpe_tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
def __call__(self, action_chunk: np.array) -> np.array:
|
||||||
|
from scipy.fft import dct
|
||||||
|
|
||||||
|
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
|
||||||
|
if action_chunk.ndim == 2:
|
||||||
|
action_chunk = action_chunk[None, ...]
|
||||||
|
|
||||||
|
# Cache the time horizon and action dimension for decoding
|
||||||
|
self.called_time_horizon = action_chunk.shape[-2]
|
||||||
|
self.called_action_dim = action_chunk.shape[-1]
|
||||||
|
|
||||||
|
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
|
||||||
|
dct_coeff = np.around(dct_coeff * self.scale)
|
||||||
|
tokens = []
|
||||||
|
for elem in dct_coeff:
|
||||||
|
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
|
||||||
|
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
tokens: list[list[int]],
|
||||||
|
*,
|
||||||
|
time_horizon: int | None = None,
|
||||||
|
action_dim: int | None = None,
|
||||||
|
) -> np.array:
|
||||||
|
from scipy.fft import idct
|
||||||
|
|
||||||
|
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
|
||||||
|
self.action_dim = action_dim or self.action_dim or self.called_action_dim
|
||||||
|
|
||||||
|
# Cache the time horizon and action dimension for the next call
|
||||||
|
self.called_time_horizon = self.time_horizon
|
||||||
|
self.called_action_dim = self.action_dim
|
||||||
|
|
||||||
|
assert self.time_horizon is not None and self.action_dim is not None, (
|
||||||
|
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded_actions = []
|
||||||
|
for token in tokens:
|
||||||
|
try:
|
||||||
|
decoded_tokens = self.bpe_tokenizer.decode(token)
|
||||||
|
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
|
||||||
|
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||||
|
assert decoded_dct_coeff.shape == (
|
||||||
|
self.time_horizon,
|
||||||
|
self.action_dim,
|
||||||
|
), (
|
||||||
|
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decoding tokens: {e}")
|
||||||
|
print(f"Tokens: {token}")
|
||||||
|
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||||
|
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||||
|
return np.stack(decoded_actions)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fit(
|
||||||
|
cls,
|
||||||
|
action_data: list[np.array],
|
||||||
|
scale: float = 10,
|
||||||
|
vocab_size: int = 1024,
|
||||||
|
*,
|
||||||
|
time_horizon: int | None = None,
|
||||||
|
action_dim: int | None = None,
|
||||||
|
) -> "UniversalActionProcessor":
|
||||||
|
from scipy.fft import dct
|
||||||
|
|
||||||
|
# Run DCT over all inputs
|
||||||
|
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
|
||||||
|
|
||||||
|
# Quantize and find min token
|
||||||
|
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
|
||||||
|
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
|
||||||
|
min_vocab_size = max_token - min_token
|
||||||
|
|
||||||
|
assert min_vocab_size <= vocab_size, (
|
||||||
|
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
|
||||||
|
)
|
||||||
|
if min_vocab_size + 100 > vocab_size:
|
||||||
|
logging.warning(
|
||||||
|
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
|
||||||
|
f"size {vocab_size}, consider increasing vocab size"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make token iterator for BPE training
|
||||||
|
def _token_iter():
|
||||||
|
for tokens in dct_tokens:
|
||||||
|
rounded_tokens = np.around(tokens * scale) - min_token
|
||||||
|
rounded_tokens = rounded_tokens.astype(int)
|
||||||
|
string = "".join(map(chr, rounded_tokens))
|
||||||
|
yield string
|
||||||
|
|
||||||
|
# Train BPE tokenizer
|
||||||
|
bpe = ByteLevelBPETokenizer()
|
||||||
|
|
||||||
|
# Set up the entire range of possible tokens as the initial alphabet
|
||||||
|
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
|
||||||
|
trainer = BpeTrainer(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
min_frequency=2,
|
||||||
|
show_progress=True,
|
||||||
|
special_tokens=[],
|
||||||
|
initial_alphabet=alphabet,
|
||||||
|
max_token_length=10000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
|
||||||
|
# because it doesn't support custom alphabets)
|
||||||
|
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
|
||||||
|
scale=scale,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
min_token=min_token,
|
||||||
|
time_horizon=time_horizon,
|
||||||
|
action_dim=action_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained_local(
|
||||||
|
cls,
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
*,
|
||||||
|
revision: str | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
) -> "UniversalActionProcessor":
|
||||||
|
location = Path(
|
||||||
|
_resolve_tokenizer_location(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
revision=revision,
|
||||||
|
force_download=force_download,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
processor_config = {}
|
||||||
|
processor_config_path = location / "processor_config.json"
|
||||||
|
if processor_config_path.exists():
|
||||||
|
import json
|
||||||
|
|
||||||
|
processor_config = json.loads(processor_config_path.read_text())
|
||||||
|
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
|
||||||
|
return cls(
|
||||||
|
tokenizer,
|
||||||
|
scale=processor_config.get("scale", 10),
|
||||||
|
vocab_size=processor_config.get("vocab_size", 1024),
|
||||||
|
min_token=processor_config.get("min_token", 0),
|
||||||
|
action_dim=processor_config.get("action_dim"),
|
||||||
|
time_horizon=processor_config.get("time_horizon"),
|
||||||
|
)
|
||||||
@@ -0,0 +1,553 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
"""
|
||||||
|
MolmoAct2 configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2VitConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
|
||||||
|
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
|
||||||
|
defining the model architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2VitConfig
|
||||||
|
>>> configuration = MolmoAct2VitConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
|
||||||
|
>>> model = MolmoAct2VisionTransformer(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "molmoact2"
|
||||||
|
base_config_key = "vit_config"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 1152,
|
||||||
|
intermediate_size: int = 4304,
|
||||||
|
num_hidden_layers: int = 27,
|
||||||
|
num_attention_heads: int = 16,
|
||||||
|
num_key_value_heads: int = 16,
|
||||||
|
head_dim: int = 72,
|
||||||
|
hidden_act: str = "gelu_pytorch_tanh",
|
||||||
|
layer_norm_eps: float = 1e-6,
|
||||||
|
image_default_input_size: tuple[int, int] = (378, 378),
|
||||||
|
image_patch_size: int = 14,
|
||||||
|
image_num_pos: int = 577,
|
||||||
|
attention_dropout: float = 0.0,
|
||||||
|
residual_dropout: float = 0.0,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
float32_attention: bool = True,
|
||||||
|
attn_implementation: str = "eager",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.attn_implementation = attn_implementation
|
||||||
|
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.image_default_input_size = image_default_input_size
|
||||||
|
self.image_patch_size = image_patch_size
|
||||||
|
self.image_num_pos = image_num_pos
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.residual_dropout = residual_dropout
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.float32_attention = float32_attention
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_num_patch(self):
|
||||||
|
h, w = self.image_default_input_size
|
||||||
|
return h // self.image_patch_size, w // self.image_patch_size
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2AdapterConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
|
||||||
|
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
|
||||||
|
defining the model architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
|
||||||
|
>>> vit_config = MolmoAct2VitConfig()
|
||||||
|
>>> adapter_config = MolmoPoolingConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
|
||||||
|
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> vit_configuration = model.vit_config
|
||||||
|
>>> adapter_configuration = model.adapter_config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "molmoact2"
|
||||||
|
base_config_key = "adapter_config"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vit_layers: tuple = (-3, -9),
|
||||||
|
pooling_attention_mask: bool = False,
|
||||||
|
hidden_size: int = 1152,
|
||||||
|
num_attention_heads: int = 16,
|
||||||
|
num_key_value_heads: int = 16,
|
||||||
|
head_dim: int = 72,
|
||||||
|
float32_attention: bool = True,
|
||||||
|
attention_dropout: float = 0.0,
|
||||||
|
residual_dropout: float = 0.0,
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
intermediate_size: int = 18944,
|
||||||
|
text_hidden_size: int = 3584,
|
||||||
|
image_feature_dropout: float = 0.0,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
attn_implementation: str = "eager",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.attn_implementation = attn_implementation
|
||||||
|
super().__init__(attn_implementation=attn_implementation, **kwargs)
|
||||||
|
self.vit_layers = vit_layers
|
||||||
|
self.pooling_attention_mask = pooling_attention_mask
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.float32_attention = float32_attention
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.residual_dropout = residual_dropout
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.text_hidden_size = text_hidden_size
|
||||||
|
self.image_feature_dropout = image_feature_dropout
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2TextConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
|
||||||
|
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2TextConfig
|
||||||
|
>>> configuration = MolmoAct2TextConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2TextModel (with random weights)
|
||||||
|
>>> model = MolmoAct2TextModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "molmoact2_text"
|
||||||
|
base_config_key = "text_config"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"blocks.*.self_attn.att_proj": "colwise",
|
||||||
|
"blocks.*.self_attn.attn_out": "rowwise",
|
||||||
|
"blocks.*.mlp.ff_proj": "colwise",
|
||||||
|
"blocks.*.mlp.ff_out": "rowwise",
|
||||||
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"wte": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"ln_f": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int = 3584,
|
||||||
|
num_attention_heads: int = 28,
|
||||||
|
num_key_value_heads: int | None = 4,
|
||||||
|
head_dim: int = 128,
|
||||||
|
vocab_size: int = 152064,
|
||||||
|
additional_vocab_size: int = 128,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
num_hidden_layers: int = 48,
|
||||||
|
intermediate_size: int = 18944,
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
embedding_dropout: float = 0.0,
|
||||||
|
attention_dropout: float = 0.0,
|
||||||
|
residual_dropout: float = 0.0,
|
||||||
|
max_position_embeddings: int = 4096,
|
||||||
|
rope_theta: float = 1000000.0,
|
||||||
|
rope_scaling: dict[str, Any] = None,
|
||||||
|
rope_scaling_layers: list[int] | None = None,
|
||||||
|
use_qk_norm: bool = False,
|
||||||
|
qk_norm_type: str = "olmo",
|
||||||
|
layer_norm_eps: int = 1e-6,
|
||||||
|
norm_after: bool = False,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
use_cache=True,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
attn_implementation: str = "eager",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.attn_implementation = attn_implementation
|
||||||
|
super().__init__(
|
||||||
|
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
|
||||||
|
)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.additional_vocab_size = additional_vocab_size
|
||||||
|
self.qkv_bias = qkv_bias
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.embedding_dropout = embedding_dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.residual_dropout = residual_dropout
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.rope_scaling_layers = rope_scaling_layers
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
|
self.qk_norm_type = qk_norm_type
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.norm_after = norm_after
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.use_cache = use_cache
|
||||||
|
|
||||||
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
|
rope_config_validation(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2ActionExpertConfig(PretrainedConfig):
|
||||||
|
r"""Configuration for the MolmoAct2 modern action expert."""
|
||||||
|
|
||||||
|
model_type = "molmoact2_action_expert"
|
||||||
|
base_config_key = "action_expert_config"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_action_horizon: int = 32,
|
||||||
|
max_action_dim: int = 32,
|
||||||
|
hidden_size: int = 1024,
|
||||||
|
num_layers: int = 32,
|
||||||
|
num_heads: int = 16,
|
||||||
|
mlp_ratio: float = 8.0 / 3.0,
|
||||||
|
ffn_multiple_of: int = 256,
|
||||||
|
timestep_embed_dim: int = 256,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
attn_dropout: float = 0.0,
|
||||||
|
context_layer_norm: bool = True,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
qk_norm_eps: float = 1e-6,
|
||||||
|
rope: bool = True,
|
||||||
|
causal_attn: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.max_action_horizon = max_action_horizon
|
||||||
|
self.max_action_dim = max_action_dim
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.ffn_multiple_of = ffn_multiple_of
|
||||||
|
self.timestep_embed_dim = timestep_embed_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attn_dropout = attn_dropout
|
||||||
|
self.context_layer_norm = context_layer_norm
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
self.qk_norm_eps = qk_norm_eps
|
||||||
|
self.rope = rope
|
||||||
|
self.causal_attn = causal_attn
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
output = super().to_dict()
|
||||||
|
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
|
||||||
|
# them out of the public nested config avoids duplicated sources of truth.
|
||||||
|
output.pop("max_action_horizon", None)
|
||||||
|
output.pop("max_action_dim", None)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
|
||||||
|
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2VitConfig
|
||||||
|
>>> vit_config = MolmoAct2VitConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2AdapterConfig
|
||||||
|
>>> adapter_config = MolmoAct2AdapterConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2TextConfig
|
||||||
|
>>> text_config = MolmoAct2TextConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a MolmoAct2Config
|
||||||
|
>>> configuration = MolmoAct2Config(
|
||||||
|
>>> vit_config=vit_config,
|
||||||
|
>>> adapter_config=adapter_config,
|
||||||
|
>>> text_config=text_config,
|
||||||
|
>>> image_start_token_id=151936,
|
||||||
|
>>> image_end_token_id=151937,
|
||||||
|
>>> image_patch_id=151938,
|
||||||
|
>>> image_col_id=151939,
|
||||||
|
>>> low_res_image_start_token_id=151940,
|
||||||
|
>>> image_low_res_id=151942,
|
||||||
|
>>> frame_start_token_id=151943,
|
||||||
|
>>> frame_end_token_id=151944,
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
>>> # Initializing a model
|
||||||
|
>>> model = MolmoAct2ForConditionalGeneration(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "molmoact2"
|
||||||
|
sub_configs = {
|
||||||
|
"text_config": MolmoAct2TextConfig,
|
||||||
|
"vit_config": MolmoAct2VitConfig,
|
||||||
|
"adapter_config": MolmoAct2AdapterConfig,
|
||||||
|
"action_expert_config": MolmoAct2ActionExpertConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vit_config: MolmoAct2VitConfig = None,
|
||||||
|
adapter_config: MolmoAct2AdapterConfig = None,
|
||||||
|
text_config: MolmoAct2TextConfig = None,
|
||||||
|
action_expert_config: MolmoAct2ActionExpertConfig = None,
|
||||||
|
image_start_token_id: int = None,
|
||||||
|
low_res_image_start_token_id: int = None,
|
||||||
|
image_end_token_id: int = None,
|
||||||
|
image_low_res_id: int = None,
|
||||||
|
image_patch_id: int = None,
|
||||||
|
image_col_id: int = None,
|
||||||
|
frame_start_token_id: int = None,
|
||||||
|
frame_end_token_id: int = None,
|
||||||
|
use_frame_special_tokens: bool = True,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
add_action_expert: bool = True,
|
||||||
|
max_action_dim: int = 32,
|
||||||
|
max_action_horizon: int = 30,
|
||||||
|
n_obs_steps: int = 30,
|
||||||
|
action_mode: str = "both",
|
||||||
|
state_format: str = "discrete",
|
||||||
|
flow_matching_num_steps: int = 10,
|
||||||
|
flow_matching_cutoff: float = 1.0,
|
||||||
|
flow_matching_time_offset: float = 0.001,
|
||||||
|
flow_matching_time_scale: float = 0.999,
|
||||||
|
flow_matching_beta_alpha: float = 1.0,
|
||||||
|
flow_matching_beta_beta: float = 1.5,
|
||||||
|
mask_action_dim_padding: bool = True,
|
||||||
|
enable_depth_reasoning: bool = False,
|
||||||
|
depth_mode: int = 2,
|
||||||
|
num_depth_codes: int = 100,
|
||||||
|
action_expert_depth_gate: bool = False,
|
||||||
|
action_expert_depth_gate_per_layer: bool = False,
|
||||||
|
action_expert_depth_gate_init_bias: float = -4.0,
|
||||||
|
action_output_token_id: int = None,
|
||||||
|
action_start_token_id: int = None,
|
||||||
|
action_end_token_id: int = None,
|
||||||
|
action_token_start_id: int = None,
|
||||||
|
num_action_tokens: int = 0,
|
||||||
|
depth_output_token_id: int = None,
|
||||||
|
depth_start_token_id: int = None,
|
||||||
|
depth_end_token_id: int = None,
|
||||||
|
depth_token_start_id: int = None,
|
||||||
|
num_depth_tokens: int = 0,
|
||||||
|
state_start_token_id: int = None,
|
||||||
|
state_end_token_id: int = None,
|
||||||
|
state_token_start_id: int = None,
|
||||||
|
num_state_tokens: int = 0,
|
||||||
|
add_setup_tokens: bool = True,
|
||||||
|
add_control_tokens: bool = True,
|
||||||
|
norm_stats_filename: str = "norm_stats.json",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
if vit_config is None:
|
||||||
|
self.vit_config = MolmoAct2VitConfig()
|
||||||
|
elif isinstance(vit_config, dict):
|
||||||
|
self.vit_config = MolmoAct2VitConfig(**vit_config)
|
||||||
|
else:
|
||||||
|
self.vit_config = vit_config
|
||||||
|
if adapter_config is None:
|
||||||
|
self.adapter_config = MolmoAct2AdapterConfig()
|
||||||
|
elif isinstance(adapter_config, dict):
|
||||||
|
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
|
||||||
|
else:
|
||||||
|
self.adapter_config = adapter_config
|
||||||
|
if text_config is None:
|
||||||
|
self.text_config = MolmoAct2TextConfig()
|
||||||
|
elif isinstance(text_config, dict):
|
||||||
|
self.text_config = MolmoAct2TextConfig(**text_config)
|
||||||
|
else:
|
||||||
|
self.text_config = text_config
|
||||||
|
self.add_action_expert = bool(add_action_expert)
|
||||||
|
if not self.add_action_expert:
|
||||||
|
self.action_expert_config = None
|
||||||
|
elif action_expert_config is None:
|
||||||
|
self.action_expert_config = MolmoAct2ActionExpertConfig(
|
||||||
|
max_action_horizon=max_action_horizon,
|
||||||
|
max_action_dim=max_action_dim,
|
||||||
|
num_layers=self.text_config.num_hidden_layers,
|
||||||
|
)
|
||||||
|
elif isinstance(action_expert_config, dict):
|
||||||
|
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
|
||||||
|
else:
|
||||||
|
self.action_expert_config = action_expert_config
|
||||||
|
if self.add_action_expert:
|
||||||
|
self.action_expert_config.max_action_dim = int(max_action_dim)
|
||||||
|
self.action_expert_config.max_action_horizon = int(max_action_horizon)
|
||||||
|
self._validate_release_action_config(
|
||||||
|
state_format=state_format,
|
||||||
|
)
|
||||||
|
self.image_start_token_id = image_start_token_id
|
||||||
|
self.low_res_image_start_token_id = low_res_image_start_token_id
|
||||||
|
self.image_end_token_id = image_end_token_id
|
||||||
|
self.image_low_res_id = image_low_res_id
|
||||||
|
self.image_high_res_id = image_patch_id
|
||||||
|
self.image_patch_id = image_patch_id
|
||||||
|
self.image_col_id = image_col_id
|
||||||
|
self.frame_start_token_id = frame_start_token_id
|
||||||
|
self.frame_end_token_id = frame_end_token_id
|
||||||
|
self.use_frame_special_tokens = use_frame_special_tokens
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.max_action_dim = max_action_dim
|
||||||
|
self.max_action_horizon = max_action_horizon
|
||||||
|
self.n_obs_steps = n_obs_steps
|
||||||
|
self.action_mode = action_mode
|
||||||
|
self.state_format = state_format
|
||||||
|
self.flow_matching_num_steps = flow_matching_num_steps
|
||||||
|
self.flow_matching_cutoff = flow_matching_cutoff
|
||||||
|
self.flow_matching_time_offset = flow_matching_time_offset
|
||||||
|
self.flow_matching_time_scale = flow_matching_time_scale
|
||||||
|
self.flow_matching_beta_alpha = flow_matching_beta_alpha
|
||||||
|
self.flow_matching_beta_beta = flow_matching_beta_beta
|
||||||
|
self.mask_action_dim_padding = mask_action_dim_padding
|
||||||
|
self.enable_depth_reasoning = enable_depth_reasoning
|
||||||
|
self.depth_mode = depth_mode
|
||||||
|
self.num_depth_codes = num_depth_codes
|
||||||
|
self.action_expert_depth_gate = action_expert_depth_gate
|
||||||
|
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
|
||||||
|
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
|
||||||
|
self.action_output_token_id = action_output_token_id
|
||||||
|
self.action_start_token_id = action_start_token_id
|
||||||
|
self.action_end_token_id = action_end_token_id
|
||||||
|
self.action_token_start_id = action_token_start_id
|
||||||
|
self.num_action_tokens = num_action_tokens
|
||||||
|
self.depth_output_token_id = depth_output_token_id
|
||||||
|
self.depth_start_token_id = depth_start_token_id
|
||||||
|
self.depth_end_token_id = depth_end_token_id
|
||||||
|
self.depth_token_start_id = depth_token_start_id
|
||||||
|
self.num_depth_tokens = num_depth_tokens
|
||||||
|
self.state_start_token_id = state_start_token_id
|
||||||
|
self.state_end_token_id = state_end_token_id
|
||||||
|
self.state_token_start_id = state_token_start_id
|
||||||
|
self.num_state_tokens = num_state_tokens
|
||||||
|
self.add_setup_tokens = add_setup_tokens
|
||||||
|
self.add_control_tokens = add_control_tokens
|
||||||
|
self.norm_stats_filename = norm_stats_filename
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_release_action_config(
|
||||||
|
*,
|
||||||
|
state_format: str,
|
||||||
|
) -> None:
|
||||||
|
if state_format != "discrete":
|
||||||
|
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_num_patch(self):
|
||||||
|
assert self.vit_config is not None
|
||||||
|
return self.vit_config.image_num_patch
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_attention_heads(self):
|
||||||
|
return self.text_config.num_attention_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_heads(self):
|
||||||
|
return self.text_config.num_key_value_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.text_config.head_dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_hidden_layers(self):
|
||||||
|
return self.text_config.num_hidden_layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_size(self):
|
||||||
|
return self.text_config.hidden_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return self.text_config.vocab_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_position_embeddings(self):
|
||||||
|
return self.text_config.max_position_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
MolmoAct2VitConfig.register_for_auto_class()
|
||||||
|
MolmoAct2AdapterConfig.register_for_auto_class()
|
||||||
|
MolmoAct2TextConfig.register_for_auto_class()
|
||||||
|
MolmoAct2ActionExpertConfig.register_for_auto_class()
|
||||||
|
MolmoAct2Config.register_for_auto_class()
|
||||||
@@ -0,0 +1,564 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
"""Image processor class for MolmoAct2"""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
import numpy as np
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms
|
||||||
|
|
||||||
|
from transformers.image_utils import (
|
||||||
|
IMAGENET_STANDARD_MEAN,
|
||||||
|
IMAGENET_STANDARD_STD,
|
||||||
|
ImageInput,
|
||||||
|
PILImageResampling,
|
||||||
|
make_flat_list_of_images,
|
||||||
|
valid_images,
|
||||||
|
to_numpy_array,
|
||||||
|
)
|
||||||
|
from transformers.image_transforms import convert_to_rgb
|
||||||
|
from transformers.processing_utils import ImagesKwargs
|
||||||
|
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.utils import TensorType, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_image(
|
||||||
|
image: np.ndarray,
|
||||||
|
image_mean: list[float],
|
||||||
|
image_std: list[float],
|
||||||
|
) -> np.ndarray:
|
||||||
|
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
|
||||||
|
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
|
||||||
|
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
|
||||||
|
image /= np.array(image_std, dtype=np.float32)[None, None, :]
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(
|
||||||
|
image: np.ndarray,
|
||||||
|
desired_output_size: list[int],
|
||||||
|
resample: PILImageResampling,
|
||||||
|
) -> np.ndarray:
|
||||||
|
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
|
||||||
|
dtype = image.dtype
|
||||||
|
if torch.is_floating_point(image):
|
||||||
|
in_min = 0.0
|
||||||
|
in_max = 1.0
|
||||||
|
resized = torchvision.transforms.Resize(
|
||||||
|
desired_output_size,
|
||||||
|
resample,
|
||||||
|
antialias=False,
|
||||||
|
)(image)
|
||||||
|
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
|
||||||
|
else:
|
||||||
|
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
|
||||||
|
image.dtype
|
||||||
|
)
|
||||||
|
in_min = 0.0
|
||||||
|
in_max = 255.0
|
||||||
|
resized = torchvision.transforms.Resize(
|
||||||
|
desired_output_size,
|
||||||
|
resample,
|
||||||
|
antialias=False,
|
||||||
|
)(image)
|
||||||
|
resized = torch.clip(resized, 0, 255).to(dtype)
|
||||||
|
|
||||||
|
resized = resized.to(torch.float32)
|
||||||
|
resized = (resized - in_min) / (in_max - in_min)
|
||||||
|
|
||||||
|
resized = torch.permute(resized, [1, 2, 0]).numpy()
|
||||||
|
|
||||||
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
def select_tiling(h, w, patch_size, max_num_crops):
|
||||||
|
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
||||||
|
original_size = np.stack([h, w]) # [1, 2]
|
||||||
|
original_res = h * w
|
||||||
|
tilings = []
|
||||||
|
for i in range(1, max_num_crops + 1):
|
||||||
|
for j in range(1, max_num_crops + 1):
|
||||||
|
if i * j <= max_num_crops:
|
||||||
|
tilings.append((i, j))
|
||||||
|
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
||||||
|
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
|
||||||
|
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
|
||||||
|
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
||||||
|
|
||||||
|
# How much we would need to scale the image to fit exactly in each tiling
|
||||||
|
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
|
||||||
|
|
||||||
|
# The original size can be zero in rare cases if the image is smaller than the margin
|
||||||
|
# In those cases letting the scale become infinite means the tiling is based on the
|
||||||
|
# other side, or falls back to the smallest tiling
|
||||||
|
with np.errstate(divide="ignore"):
|
||||||
|
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
|
||||||
|
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
||||||
|
if np.all(required_scale < 1):
|
||||||
|
# We are forced to downscale, so try to minimize the amount of downscaling
|
||||||
|
ix = np.argmax(required_scale)
|
||||||
|
else:
|
||||||
|
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
||||||
|
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
|
||||||
|
ix = np.argmin(required_scale)
|
||||||
|
return candidate_tilings[ix]
|
||||||
|
|
||||||
|
|
||||||
|
def build_resized_image(
|
||||||
|
image: np.ndarray,
|
||||||
|
base_image_input_size: list[int],
|
||||||
|
resample: PILImageResampling,
|
||||||
|
image_mean: list[float],
|
||||||
|
image_std: list[float],
|
||||||
|
image_patch_size: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
resized = resize_image(
|
||||||
|
image,
|
||||||
|
base_image_input_size,
|
||||||
|
resample,
|
||||||
|
)
|
||||||
|
resized = normalize_image(resized, image_mean, image_std)
|
||||||
|
if len(resized.shape) == 3:
|
||||||
|
resized = np.expand_dims(resized, 0)
|
||||||
|
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||||
|
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||||
|
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
|
||||||
|
return resized, resize_idx
|
||||||
|
|
||||||
|
|
||||||
|
def build_overlapping_crops(
|
||||||
|
image: np.ndarray,
|
||||||
|
max_crops: int,
|
||||||
|
overlap_margins: list[int],
|
||||||
|
base_image_input_size: list[int],
|
||||||
|
resample: PILImageResampling,
|
||||||
|
image_mean: list[float],
|
||||||
|
image_std: list[float],
|
||||||
|
image_patch_size: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Decompose an image into a set of overlapping crops
|
||||||
|
|
||||||
|
:return crop_arr: [n_crops, h, w, 3] The crops
|
||||||
|
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
|
||||||
|
the crops were extracted from, what patch in `crop_arr` it corresponds to
|
||||||
|
"""
|
||||||
|
original_image_h, original_image_w = image.shape[:2]
|
||||||
|
crop_size = base_image_input_size[0]
|
||||||
|
assert base_image_input_size[0] == base_image_input_size[1]
|
||||||
|
|
||||||
|
left_margin, right_margin = overlap_margins
|
||||||
|
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
|
||||||
|
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
|
||||||
|
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
||||||
|
crop_window_size = crop_window_patches * image_patch_size
|
||||||
|
crop_patch_w = base_image_input_size[1] // image_patch_size
|
||||||
|
crop_patch_h = base_image_input_size[0] // image_patch_size
|
||||||
|
original_image_h, original_image_w = image.shape[:2]
|
||||||
|
crop_size = base_image_input_size[0]
|
||||||
|
|
||||||
|
# Decide how to tile the image, to account for the overlap margins we compute the tiling
|
||||||
|
# as if we had an image without the margins and were using a crop size without the margins
|
||||||
|
tiling = select_tiling(
|
||||||
|
original_image_h - total_margin_pixels,
|
||||||
|
original_image_w - total_margin_pixels,
|
||||||
|
crop_window_size,
|
||||||
|
max_crops,
|
||||||
|
)
|
||||||
|
|
||||||
|
src = resize_image(
|
||||||
|
image,
|
||||||
|
[
|
||||||
|
tiling[0] * crop_window_size + total_margin_pixels,
|
||||||
|
tiling[1] * crop_window_size + total_margin_pixels,
|
||||||
|
],
|
||||||
|
resample,
|
||||||
|
)
|
||||||
|
src = normalize_image(src, image_mean, image_std)
|
||||||
|
|
||||||
|
# Now we have to split the image into crops, and track what patches came from
|
||||||
|
# where in `patch_idx_arr`
|
||||||
|
n_crops = tiling[0] * tiling[1]
|
||||||
|
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
|
||||||
|
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
|
||||||
|
on_crop = 0
|
||||||
|
for i in range(tiling[0]):
|
||||||
|
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
|
||||||
|
# which results in overlapping crop windows
|
||||||
|
y0 = i * crop_window_size
|
||||||
|
for j in range(tiling[1]):
|
||||||
|
x0 = j * crop_window_size
|
||||||
|
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
|
||||||
|
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
|
||||||
|
patch_idx += on_crop * crop_patch_h * crop_patch_w
|
||||||
|
|
||||||
|
# Mask out idx that are in the overlap region
|
||||||
|
if i != 0:
|
||||||
|
patch_idx[:left_margin, :] = -1
|
||||||
|
if j != 0:
|
||||||
|
patch_idx[:, :left_margin] = -1
|
||||||
|
if i != tiling[0] - 1:
|
||||||
|
patch_idx[-right_margin:, :] = -1
|
||||||
|
if j != tiling[1] - 1:
|
||||||
|
patch_idx[:, -right_margin:] = -1
|
||||||
|
patch_idx_arr[on_crop] = patch_idx
|
||||||
|
on_crop += 1
|
||||||
|
|
||||||
|
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
|
||||||
|
# so it is ordered left-to-right order
|
||||||
|
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
|
||||||
|
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
|
||||||
|
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
|
||||||
|
|
||||||
|
# Now get the parts not in the overlap region, so it should map each patch in `src`
|
||||||
|
# to the correct patch it should come from in `crop_arr`
|
||||||
|
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
|
||||||
|
src.shape[0] // image_patch_size,
|
||||||
|
src.shape[1] // image_patch_size,
|
||||||
|
)
|
||||||
|
return crop_arr, patch_idx_arr
|
||||||
|
|
||||||
|
|
||||||
|
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
|
||||||
|
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
|
||||||
|
if len(array.shape) == 3:
|
||||||
|
n_crops, h, w = array.shape
|
||||||
|
h_patches = h // patch_size
|
||||||
|
w_patches = w // patch_size
|
||||||
|
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
|
||||||
|
array = np.transpose(array, [0, 1, 3, 2, 4])
|
||||||
|
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
|
||||||
|
return array
|
||||||
|
else:
|
||||||
|
n_crops, h, w, c = array.shape
|
||||||
|
h_patches = h // patch_size
|
||||||
|
w_patches = w // patch_size
|
||||||
|
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
|
||||||
|
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
|
||||||
|
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
|
||||||
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
def arange_for_pooling(
|
||||||
|
idx_arr: np.ndarray,
|
||||||
|
pool_h: int,
|
||||||
|
pool_w: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
|
||||||
|
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
|
||||||
|
idx_arr = np.pad(
|
||||||
|
idx_arr,
|
||||||
|
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
|
||||||
|
mode="constant",
|
||||||
|
constant_values=-1,
|
||||||
|
)
|
||||||
|
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_patches_and_grids(
|
||||||
|
image: np.ndarray,
|
||||||
|
max_crops: int,
|
||||||
|
overlap_margins: list[int],
|
||||||
|
base_image_input_size: list[int],
|
||||||
|
resample: PILImageResampling,
|
||||||
|
image_mean: list[float],
|
||||||
|
image_std: list[float],
|
||||||
|
image_patch_size: int,
|
||||||
|
image_pooling_w: int,
|
||||||
|
image_pooling_h: int,
|
||||||
|
crop_mode: str = "overlap-and-resize-c2",
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
:return image_grids, the shape of each (low-res, high-res) image after pooling
|
||||||
|
:return crops, the image crops to processes with the ViT
|
||||||
|
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
|
||||||
|
patches in `crops` to pool for that token, masked with -1
|
||||||
|
"""
|
||||||
|
if isinstance(base_image_input_size, int):
|
||||||
|
base_image_input_size = (base_image_input_size, base_image_input_size)
|
||||||
|
|
||||||
|
base_image_input_d = image_patch_size
|
||||||
|
pooling_w = image_pooling_w
|
||||||
|
pooling_h = image_pooling_h
|
||||||
|
crop_patch_w = base_image_input_size[1] // base_image_input_d
|
||||||
|
crop_patch_h = base_image_input_size[0] // base_image_input_d
|
||||||
|
|
||||||
|
if crop_mode == "resize":
|
||||||
|
resized, resize_idx = build_resized_image(
|
||||||
|
image,
|
||||||
|
base_image_input_size,
|
||||||
|
resample,
|
||||||
|
image_mean,
|
||||||
|
image_std,
|
||||||
|
image_patch_size,
|
||||||
|
)
|
||||||
|
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||||
|
resized_h, resized_w = resize_idx.shape[:2]
|
||||||
|
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||||
|
image_grid = [np.array([resized_h, resized_w, 0, 0])]
|
||||||
|
return (
|
||||||
|
np.stack(image_grid, 0),
|
||||||
|
batch_pixels_to_patches(resized, image_patch_size),
|
||||||
|
resize_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
|
||||||
|
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
|
||||||
|
|
||||||
|
crop_arr, patch_idx_arr = build_overlapping_crops(
|
||||||
|
image,
|
||||||
|
max_crops,
|
||||||
|
overlap_margins,
|
||||||
|
base_image_input_size,
|
||||||
|
resample,
|
||||||
|
image_mean,
|
||||||
|
image_std,
|
||||||
|
image_patch_size,
|
||||||
|
)
|
||||||
|
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
|
||||||
|
h, w = pooling_idx.shape[:2]
|
||||||
|
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
|
||||||
|
|
||||||
|
# Finally do the same for the global image
|
||||||
|
resized, resize_idx = build_resized_image(
|
||||||
|
image,
|
||||||
|
base_image_input_size,
|
||||||
|
resample,
|
||||||
|
image_mean,
|
||||||
|
image_std,
|
||||||
|
image_patch_size,
|
||||||
|
)
|
||||||
|
crop_arr = np.concatenate([resized, crop_arr], 0)
|
||||||
|
|
||||||
|
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
|
||||||
|
resized_h, resized_w = resize_idx.shape[:2]
|
||||||
|
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
|
||||||
|
|
||||||
|
# Global image goes first, so the order of patches in previous crops gets increased
|
||||||
|
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
|
||||||
|
pooling_idx = np.concatenate([resize_idx, pooling_idx])
|
||||||
|
image_grid = [np.array([resized_h, resized_w, h, w])]
|
||||||
|
|
||||||
|
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
|
||||||
|
max_crops: int | None
|
||||||
|
overlap_margins: list[int] | None
|
||||||
|
crop_mode: str | None
|
||||||
|
patch_size: int | None
|
||||||
|
pooling_size: list[int] | None
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2ImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a MolmoAct2 image processor that preprocesses images for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
|
||||||
|
Size of the image after resizing.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||||
|
Resampling filter to use when resizing the image.
|
||||||
|
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||||
|
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||||
|
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
||||||
|
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
max_crops (`int`, *optional*, defaults to `8`):
|
||||||
|
Maximum number of crops to use per image.
|
||||||
|
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
|
||||||
|
Overlap margins to use.
|
||||||
|
patch_size (`int`, *optional*, defaults to 14):
|
||||||
|
The spatial patch size of the vision encoder.
|
||||||
|
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
|
||||||
|
The pooling size of the vision adapter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: dict[str, int] | None = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
image_mean: float | list[float] | None = None,
|
||||||
|
image_std: float | list[float] | None = None,
|
||||||
|
do_convert_rgb: bool = True,
|
||||||
|
max_crops: int = 8,
|
||||||
|
overlap_margins: list[int] = [4, 4],
|
||||||
|
crop_mode: str = "overlap-and-resize-c2",
|
||||||
|
patch_size: int = 14,
|
||||||
|
pooling_size: list[int] = [2, 2],
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"height": 378, "width": 378}
|
||||||
|
size = get_size_dict(size, default_to_square=True)
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
self.resample = resample
|
||||||
|
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||||
|
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||||
|
self.do_convert_rgb = do_convert_rgb
|
||||||
|
|
||||||
|
self.max_crops = max_crops
|
||||||
|
self.overlap_margins = overlap_margins
|
||||||
|
self.crop_mode = crop_mode
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.pooling_size = pooling_size
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images: ImageInput,
|
||||||
|
size: dict[str, int] | None = None,
|
||||||
|
resample: PILImageResampling | None = None,
|
||||||
|
image_mean: float | list[float] | None = None,
|
||||||
|
image_std: float | list[float] | None = None,
|
||||||
|
do_convert_rgb: bool | None = None,
|
||||||
|
max_crops: int | None = None,
|
||||||
|
overlap_margins: list[int] | None = None,
|
||||||
|
crop_mode: str | None = None,
|
||||||
|
patch_size: int | None = None,
|
||||||
|
pooling_size: list[int] | None = None,
|
||||||
|
return_tensors: str | TensorType | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image to preprocess.
|
||||||
|
size (`dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Size of the image after resizing.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
|
||||||
|
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||||
|
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
|
||||||
|
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||||
|
`True`.
|
||||||
|
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||||
|
Whether to convert the image to RGB.
|
||||||
|
max_crops (`int`, *optional*, defaults to `self.max_crops`):
|
||||||
|
Maximum number of crops to use per image.
|
||||||
|
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
|
||||||
|
Overlap margins to use.
|
||||||
|
patch_size (`int`, *optional*, defaults to `self.patch_size`):
|
||||||
|
The spatial patch size of the vision encoder.
|
||||||
|
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
|
||||||
|
The pooling size of the vision adapter.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `BatchFeature` containing the following keys:
|
||||||
|
- `pixel_values`: The preprocessed images.
|
||||||
|
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
|
||||||
|
- `image_grids`: The image grids.
|
||||||
|
- `image_num_crops`: The number of crops for each image.
|
||||||
|
"""
|
||||||
|
if size is not None:
|
||||||
|
if "height" not in size or "width" not in size:
|
||||||
|
raise ValueError("size must contain 'height' and 'width' keys.")
|
||||||
|
else:
|
||||||
|
size = {**self.size}
|
||||||
|
|
||||||
|
base_image_input_size = [size["height"], size["width"]]
|
||||||
|
|
||||||
|
resample = resample or self.resample
|
||||||
|
image_mean = image_mean or self.image_mean
|
||||||
|
image_std = image_std or self.image_std
|
||||||
|
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
|
||||||
|
|
||||||
|
max_crops = max_crops or self.max_crops
|
||||||
|
overlap_margins = overlap_margins or self.overlap_margins
|
||||||
|
crop_mode = crop_mode or self.crop_mode
|
||||||
|
patch_size = patch_size or self.patch_size
|
||||||
|
pooling_size = pooling_size or self.pooling_size
|
||||||
|
|
||||||
|
image_pooling_h, image_pooling_w = pooling_size
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
images = self.fetch_images(images)
|
||||||
|
images = make_flat_list_of_images(images)
|
||||||
|
|
||||||
|
if images is not None and not valid_images(images):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_convert_rgb:
|
||||||
|
images = [convert_to_rgb(image) for image in images]
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if images is not None:
|
||||||
|
batch_grids = []
|
||||||
|
batch_crops = []
|
||||||
|
batch_pooled_patches_idx = []
|
||||||
|
batch_num_crops = []
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
image_grid, crops, pooled_idx = image_to_patches_and_grids(
|
||||||
|
image,
|
||||||
|
max_crops,
|
||||||
|
overlap_margins,
|
||||||
|
base_image_input_size,
|
||||||
|
resample,
|
||||||
|
image_mean,
|
||||||
|
image_std,
|
||||||
|
patch_size,
|
||||||
|
image_pooling_w,
|
||||||
|
image_pooling_h,
|
||||||
|
crop_mode,
|
||||||
|
)
|
||||||
|
batch_grids.append(image_grid)
|
||||||
|
batch_crops.append(crops)
|
||||||
|
batch_pooled_patches_idx.append(pooled_idx)
|
||||||
|
batch_num_crops.append(crops.shape[0])
|
||||||
|
|
||||||
|
pixel_values = np.concatenate(batch_crops, 0)
|
||||||
|
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
|
||||||
|
image_grids = np.concatenate(batch_grids, 0)
|
||||||
|
image_num_crops = np.array(batch_num_crops)
|
||||||
|
|
||||||
|
data.update(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_token_pooling=image_token_pooling,
|
||||||
|
image_grids=image_grids,
|
||||||
|
image_num_crops=image_num_crops,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BatchFeature(data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
MolmoAct2ImageProcessor.register_for_auto_class()
|
||||||
@@ -0,0 +1,748 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
"""Inference utilities for MolmoAct2"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ActionFlowInputs:
|
||||||
|
trajectory: torch.Tensor
|
||||||
|
context: Any
|
||||||
|
modulations: Sequence[Any]
|
||||||
|
action_dim_is_pad: torch.Tensor | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ActionFlowCudaGraph:
|
||||||
|
key: tuple[Any, ...]
|
||||||
|
graph: torch.cuda.CUDAGraph
|
||||||
|
static_inputs: _ActionFlowInputs
|
||||||
|
output: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DepthDecodeCudaGraphLayerStage:
|
||||||
|
residual: torch.Tensor
|
||||||
|
query: torch.Tensor
|
||||||
|
key: torch.Tensor
|
||||||
|
value: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DepthDecodeCudaGraphPostStage:
|
||||||
|
graph: torch.cuda.CUDAGraph
|
||||||
|
attn_context: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DepthDecodeCudaGraph:
|
||||||
|
cache_key: tuple[Any, ...]
|
||||||
|
pre_graph: torch.cuda.CUDAGraph
|
||||||
|
token_ids: torch.Tensor
|
||||||
|
cos: torch.Tensor
|
||||||
|
sin: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
|
||||||
|
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
|
||||||
|
output: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DepthDecodeCudaGraphSpec:
|
||||||
|
eligible: bool
|
||||||
|
cache_key_prefix: tuple[Any, ...]
|
||||||
|
num_hidden_layers: int
|
||||||
|
head_dim: int
|
||||||
|
num_attention_heads: int
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
|
||||||
|
if past_key_values is None:
|
||||||
|
return 0
|
||||||
|
seq_len = past_key_values.get_seq_length()
|
||||||
|
if torch.is_tensor(seq_len):
|
||||||
|
return int(seq_len.item())
|
||||||
|
return int(seq_len)
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_max_len_int(past_key_values: Cache | None) -> int:
|
||||||
|
if past_key_values is None:
|
||||||
|
return -1
|
||||||
|
max_len = past_key_values.get_max_cache_shape()
|
||||||
|
if torch.is_tensor(max_len):
|
||||||
|
return int(max_len.item())
|
||||||
|
return int(max_len)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_cache_key_values(
|
||||||
|
past_key_values: Cache,
|
||||||
|
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
|
||||||
|
layers = getattr(past_key_values, "layers", None)
|
||||||
|
if layers is not None:
|
||||||
|
for layer in layers:
|
||||||
|
yield getattr(layer, "keys", None), getattr(layer, "values", None)
|
||||||
|
return
|
||||||
|
for layer in past_key_values:
|
||||||
|
yield layer[0], layer[1]
|
||||||
|
|
||||||
|
|
||||||
|
class _DepthDecodeStaticLayerCache:
|
||||||
|
is_compileable = False
|
||||||
|
is_sliding = False
|
||||||
|
|
||||||
|
def __init__(self, max_cache_len: int) -> None:
|
||||||
|
self.max_cache_len = int(max_cache_len)
|
||||||
|
self.cumulative_length = 0
|
||||||
|
self.keys: torch.Tensor | None = None
|
||||||
|
self.values: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
||||||
|
bsz, n_heads = key_states.shape[:2]
|
||||||
|
self.keys = torch.empty(
|
||||||
|
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
|
||||||
|
dtype=key_states.dtype,
|
||||||
|
device=key_states.device,
|
||||||
|
)
|
||||||
|
self.values = torch.empty(
|
||||||
|
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
|
||||||
|
dtype=value_states.dtype,
|
||||||
|
device=value_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.keys is None:
|
||||||
|
self._allocate(key_states, value_states)
|
||||||
|
start = self.cumulative_length
|
||||||
|
end = start + key_states.shape[-2]
|
||||||
|
if end > self.max_cache_len:
|
||||||
|
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
|
||||||
|
self.keys[:, :, start:end, :].copy_(key_states)
|
||||||
|
self.values[:, :, start:end, :].copy_(value_states)
|
||||||
|
self.cumulative_length = end
|
||||||
|
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
|
||||||
|
|
||||||
|
def get_seq_length(self) -> int:
|
||||||
|
return self.cumulative_length
|
||||||
|
|
||||||
|
def get_max_cache_shape(self) -> int:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.cumulative_length = 0
|
||||||
|
|
||||||
|
|
||||||
|
class _DepthDecodeStaticCache(Cache):
|
||||||
|
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
|
||||||
|
text_config = config.get_text_config(decoder=True)
|
||||||
|
super().__init__(
|
||||||
|
layers=[
|
||||||
|
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
|
||||||
|
for _ in range(text_config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_seq_length(self, layer_idx: int = 0) -> int:
|
||||||
|
return self.layers[layer_idx].get_seq_length()
|
||||||
|
|
||||||
|
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
||||||
|
return self.layers[layer_idx].get_max_cache_shape()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
for layer in self.layers:
|
||||||
|
layer.reset()
|
||||||
|
|
||||||
|
|
||||||
|
class ActionCudaGraphManager:
|
||||||
|
def __init__(self, model: Any) -> None:
|
||||||
|
self.model = model
|
||||||
|
self.enabled = True
|
||||||
|
self.action_flow_graph: _ActionFlowCudaGraph | None = None
|
||||||
|
|
||||||
|
def set_enabled(self, enabled: bool) -> None:
|
||||||
|
self.enabled = bool(enabled)
|
||||||
|
|
||||||
|
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
|
||||||
|
action_model = self.model
|
||||||
|
if not self.enabled:
|
||||||
|
return False
|
||||||
|
if action_model.training or action_model._require_action_expert().training:
|
||||||
|
return False
|
||||||
|
if inputs.trajectory.device.type != "cuda":
|
||||||
|
return False
|
||||||
|
|
||||||
|
def all_on_cuda():
|
||||||
|
yield inputs.trajectory
|
||||||
|
for k, v in inputs.context.kv_contexts:
|
||||||
|
yield k
|
||||||
|
yield v
|
||||||
|
for t in (
|
||||||
|
inputs.context.cross_mask,
|
||||||
|
inputs.context.self_mask,
|
||||||
|
inputs.context.valid_action,
|
||||||
|
inputs.action_dim_is_pad,
|
||||||
|
):
|
||||||
|
if t is not None:
|
||||||
|
yield t
|
||||||
|
if inputs.context.rope_cache is not None:
|
||||||
|
yield from inputs.context.rope_cache
|
||||||
|
for step in inputs.modulations:
|
||||||
|
yield step.conditioning
|
||||||
|
for block_modulation in step.block_modulations:
|
||||||
|
yield from block_modulation
|
||||||
|
yield from step.final_modulation
|
||||||
|
|
||||||
|
return all(t.device.type == "cuda" for t in all_on_cuda())
|
||||||
|
|
||||||
|
def run_action_flow(
|
||||||
|
self,
|
||||||
|
inputs: _ActionFlowInputs,
|
||||||
|
steps: int,
|
||||||
|
run_loop,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
key = _cuda_graph_key(inputs, steps)
|
||||||
|
cache = self.action_flow_graph
|
||||||
|
if cache is None or cache.key != key:
|
||||||
|
static_inputs = _clone_static_inputs(inputs)
|
||||||
|
graph, output = _capture_cuda_graph(
|
||||||
|
lambda: run_loop(static_inputs, steps),
|
||||||
|
inputs.trajectory.device,
|
||||||
|
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
|
||||||
|
)
|
||||||
|
cache = _ActionFlowCudaGraph(
|
||||||
|
key=key,
|
||||||
|
graph=graph,
|
||||||
|
static_inputs=static_inputs,
|
||||||
|
output=output,
|
||||||
|
)
|
||||||
|
self.action_flow_graph = cache
|
||||||
|
else:
|
||||||
|
_copy_inputs_(cache.static_inputs, inputs)
|
||||||
|
|
||||||
|
cache.graph.replay()
|
||||||
|
return cache.output.clone()
|
||||||
|
|
||||||
|
|
||||||
|
class DepthDecodeCudaGraphManager:
|
||||||
|
def __init__(self, model: Any) -> None:
|
||||||
|
self.model = model
|
||||||
|
self.backbone = model.model
|
||||||
|
self.enabled = True
|
||||||
|
self.graph: _DepthDecodeCudaGraph | None = None
|
||||||
|
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
|
||||||
|
|
||||||
|
def set_enabled(self, enabled: bool) -> None:
|
||||||
|
self.enabled = bool(enabled)
|
||||||
|
|
||||||
|
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
|
||||||
|
return _DepthDecodeStaticCache(
|
||||||
|
config=self.model.config.text_config,
|
||||||
|
max_cache_len=max_cache_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
|
||||||
|
static = self.graph_spec
|
||||||
|
if static is None:
|
||||||
|
cfg = self.backbone.transformer.config
|
||||||
|
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
|
||||||
|
static = _DepthDecodeCudaGraphSpec(
|
||||||
|
eligible=(
|
||||||
|
not cfg.norm_after
|
||||||
|
and cfg.rope_scaling_layers is None
|
||||||
|
and getattr(rotary_emb, "rope_type", None) == "default"
|
||||||
|
and cfg._attn_implementation == "sdpa"
|
||||||
|
),
|
||||||
|
cache_key_prefix=(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.num_attention_heads,
|
||||||
|
cfg.num_key_value_heads,
|
||||||
|
cfg.head_dim,
|
||||||
|
cfg.num_hidden_layers,
|
||||||
|
cfg.use_qk_norm,
|
||||||
|
cfg.qk_norm_type,
|
||||||
|
cfg._attn_implementation,
|
||||||
|
),
|
||||||
|
num_hidden_layers=cfg.num_hidden_layers,
|
||||||
|
head_dim=cfg.head_dim,
|
||||||
|
num_attention_heads=cfg.num_attention_heads,
|
||||||
|
)
|
||||||
|
self.graph_spec = static
|
||||||
|
return static
|
||||||
|
|
||||||
|
def can_use(
|
||||||
|
self,
|
||||||
|
next_input_ids: torch.Tensor,
|
||||||
|
*,
|
||||||
|
past_key_values: Cache,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
) -> bool:
|
||||||
|
if not self.enabled or self.model.training or self.backbone.transformer.training:
|
||||||
|
return False
|
||||||
|
if next_input_ids.device.type != "cuda":
|
||||||
|
return False
|
||||||
|
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
|
||||||
|
return False
|
||||||
|
if not isinstance(past_key_values, _DepthDecodeStaticCache):
|
||||||
|
return False
|
||||||
|
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
|
||||||
|
return False
|
||||||
|
return self._depth_decode_spec().eligible
|
||||||
|
|
||||||
|
def _depth_decode_key(
|
||||||
|
self,
|
||||||
|
next_input_ids: torch.Tensor,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
) -> tuple[Any, ...]:
|
||||||
|
device = next_input_ids.device
|
||||||
|
return (
|
||||||
|
self._depth_decode_spec().cache_key_prefix,
|
||||||
|
device.type,
|
||||||
|
device.index,
|
||||||
|
self.model.lm_head.weight.dtype,
|
||||||
|
attention_bias.shape[-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
|
||||||
|
emb = self.backbone.transformer.rotary_emb
|
||||||
|
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
|
||||||
|
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
|
||||||
|
|
||||||
|
def _depth_decode_pre_layer(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
block = self.backbone.transformer.blocks[layer_idx]
|
||||||
|
attention = block.self_attn
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = block.attn_norm(hidden_states)
|
||||||
|
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
hidden_shape = (*input_shape, -1, attention.head_dim)
|
||||||
|
qkv = attention.att_proj(hidden_states)
|
||||||
|
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
|
||||||
|
value_states = value_states.view(hidden_shape)
|
||||||
|
|
||||||
|
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
|
||||||
|
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
|
||||||
|
|
||||||
|
if apply_qk_norm and not norm_after_view:
|
||||||
|
query_states = attention.q_norm(query_states)
|
||||||
|
key_states = attention.k_norm(key_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(hidden_shape)
|
||||||
|
key_states = key_states.view(hidden_shape)
|
||||||
|
|
||||||
|
if norm_after_view:
|
||||||
|
query_states = attention.q_norm(query_states)
|
||||||
|
key_states = attention.k_norm(key_states)
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
return residual, query_states, key_states, value_states
|
||||||
|
|
||||||
|
def _depth_decode_pre0(
|
||||||
|
self,
|
||||||
|
token_ids: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
inputs_embeds = self.model._embed_base_tokens(token_ids)
|
||||||
|
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
|
||||||
|
|
||||||
|
def _depth_decode_post_layer(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
attn_context: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
block = self.backbone.transformer.blocks[layer_idx]
|
||||||
|
attention = block.self_attn
|
||||||
|
input_shape = residual.shape[:-1]
|
||||||
|
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
|
||||||
|
attn_output = attention.attn_out(attn_output)
|
||||||
|
hidden_states = residual + block.dropout(attn_output)
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = block.ff_norm(hidden_states)
|
||||||
|
hidden_states = block.mlp(hidden_states)
|
||||||
|
hidden_states = residual + block.dropout(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _depth_decode_post_and_pre_next(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
attn_context: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||||
|
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
|
||||||
|
|
||||||
|
def _depth_decode_last_post(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
attn_context: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
|
||||||
|
return self.backbone.transformer.ln_f(hidden_states)
|
||||||
|
|
||||||
|
def _build_depth_decode_graph(
|
||||||
|
self,
|
||||||
|
next_input_ids: torch.Tensor,
|
||||||
|
*,
|
||||||
|
past_length: int,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
) -> _DepthDecodeCudaGraph:
|
||||||
|
text_config = self.backbone.transformer.config
|
||||||
|
device = next_input_ids.device
|
||||||
|
dtype = self.model.lm_head.weight.dtype
|
||||||
|
static = self._depth_decode_spec()
|
||||||
|
num_layers = static.num_hidden_layers
|
||||||
|
head_dim = static.head_dim
|
||||||
|
max_cache_len = int(attention_bias.shape[-1])
|
||||||
|
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
|
||||||
|
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
|
||||||
|
|
||||||
|
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
|
||||||
|
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
|
||||||
|
sin = torch.empty_like(cos)
|
||||||
|
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
|
||||||
|
context_shape = (1, 1, static.num_attention_heads, head_dim)
|
||||||
|
|
||||||
|
token_ids.copy_(next_input_ids)
|
||||||
|
self._select_depth_decode_rope(cos, sin, past_length=past_length)
|
||||||
|
|
||||||
|
pre_graph, pre_output = _capture_cuda_graph(
|
||||||
|
lambda: self._depth_decode_pre0(token_ids, cos, sin),
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
|
||||||
|
post_graphs = []
|
||||||
|
for layer_idx in range(num_layers - 1):
|
||||||
|
stage = stages[-1]
|
||||||
|
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||||
|
graph, output = _capture_cuda_graph(
|
||||||
|
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
|
||||||
|
self._depth_decode_post_and_pre_next(
|
||||||
|
layer_idx,
|
||||||
|
stage.residual,
|
||||||
|
attn_context,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
|
||||||
|
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
|
||||||
|
|
||||||
|
last_stage = stages[-1]
|
||||||
|
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
|
||||||
|
last_graph, last_output = _capture_cuda_graph(
|
||||||
|
lambda: self._depth_decode_last_post(
|
||||||
|
num_layers - 1,
|
||||||
|
last_stage.residual,
|
||||||
|
last_attn_context,
|
||||||
|
),
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
|
||||||
|
return _DepthDecodeCudaGraph(
|
||||||
|
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
|
||||||
|
pre_graph=pre_graph,
|
||||||
|
token_ids=token_ids,
|
||||||
|
cos=cos,
|
||||||
|
sin=sin,
|
||||||
|
positions=positions,
|
||||||
|
stages=tuple(stages),
|
||||||
|
post_graphs=tuple(post_graphs),
|
||||||
|
output=last_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_depth_decode_graph(
|
||||||
|
self,
|
||||||
|
next_input_ids: torch.Tensor,
|
||||||
|
*,
|
||||||
|
past_length: int,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
) -> _DepthDecodeCudaGraph:
|
||||||
|
key = self._depth_decode_key(next_input_ids, attention_bias)
|
||||||
|
decode_graph = self.graph
|
||||||
|
if decode_graph is None or decode_graph.cache_key != key:
|
||||||
|
decode_graph = self._build_depth_decode_graph(
|
||||||
|
next_input_ids,
|
||||||
|
past_length=past_length,
|
||||||
|
attention_bias=attention_bias,
|
||||||
|
)
|
||||||
|
self.graph = decode_graph
|
||||||
|
else:
|
||||||
|
decode_graph.token_ids.copy_(next_input_ids)
|
||||||
|
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
|
||||||
|
return decode_graph
|
||||||
|
|
||||||
|
def _run_depth_decode_attention_core(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
stage: _DepthDecodeCudaGraphLayerStage,
|
||||||
|
*,
|
||||||
|
past_key_values: Cache,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
attention = self.backbone.transformer.blocks[layer_idx].self_attn
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_values.update(
|
||||||
|
stage.key,
|
||||||
|
stage.value,
|
||||||
|
layer_idx,
|
||||||
|
cache_kwargs,
|
||||||
|
)
|
||||||
|
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
|
||||||
|
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
stage.query,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_bias,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=False,
|
||||||
|
)
|
||||||
|
return attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
next_input_ids: torch.Tensor,
|
||||||
|
*,
|
||||||
|
past_key_values: Cache,
|
||||||
|
attention_bias: torch.Tensor,
|
||||||
|
past_length: int,
|
||||||
|
) -> tuple[torch.Tensor, Cache]:
|
||||||
|
end = past_length + 1
|
||||||
|
decode_graph = self._get_depth_decode_graph(
|
||||||
|
next_input_ids,
|
||||||
|
past_length=past_length,
|
||||||
|
attention_bias=attention_bias,
|
||||||
|
)
|
||||||
|
cache_position = decode_graph.positions[past_length:end]
|
||||||
|
attention_bias_q = attention_bias[:, :, past_length:end, :end]
|
||||||
|
|
||||||
|
decode_graph.pre_graph.replay()
|
||||||
|
|
||||||
|
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
|
||||||
|
attn_context = self._run_depth_decode_attention_core(
|
||||||
|
layer_idx,
|
||||||
|
decode_graph.stages[layer_idx],
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
attention_bias=attention_bias_q,
|
||||||
|
cache_position=cache_position,
|
||||||
|
cos=decode_graph.cos,
|
||||||
|
sin=decode_graph.sin,
|
||||||
|
)
|
||||||
|
post_graph.attn_context.copy_(attn_context)
|
||||||
|
post_graph.graph.replay()
|
||||||
|
|
||||||
|
return decode_graph.output, past_key_values
|
||||||
|
|
||||||
|
|
||||||
|
def _cuda_graph_tensor_signature(
|
||||||
|
tensor: torch.Tensor | None,
|
||||||
|
) -> tuple[Any, ...] | None:
|
||||||
|
if tensor is None:
|
||||||
|
return None
|
||||||
|
return (
|
||||||
|
tuple(tensor.shape),
|
||||||
|
tuple(tensor.stride()),
|
||||||
|
str(tensor.dtype),
|
||||||
|
str(tensor.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
|
||||||
|
sig = _cuda_graph_tensor_signature
|
||||||
|
return (
|
||||||
|
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
|
||||||
|
sig(context.cross_mask),
|
||||||
|
sig(context.self_mask),
|
||||||
|
sig(context.valid_action),
|
||||||
|
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
|
||||||
|
sig = _cuda_graph_tensor_signature
|
||||||
|
return tuple(
|
||||||
|
(
|
||||||
|
sig(step.conditioning),
|
||||||
|
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
|
||||||
|
tuple(sig(t) for t in step.final_modulation),
|
||||||
|
)
|
||||||
|
for step in modulations
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
|
||||||
|
sig = _cuda_graph_tensor_signature
|
||||||
|
return (
|
||||||
|
sig(inputs.trajectory),
|
||||||
|
_cuda_graph_context_signature(inputs.context),
|
||||||
|
_cuda_graph_modulation_signature(inputs.modulations),
|
||||||
|
sig(inputs.action_dim_is_pad),
|
||||||
|
int(steps),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||||
|
if tensor is None:
|
||||||
|
return None
|
||||||
|
static = torch.empty_strided(
|
||||||
|
tuple(tensor.shape),
|
||||||
|
tuple(tensor.stride()),
|
||||||
|
device=tensor.device,
|
||||||
|
dtype=tensor.dtype,
|
||||||
|
)
|
||||||
|
static.copy_(tensor)
|
||||||
|
return static
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_static_context(context: Any) -> Any:
|
||||||
|
rope_cache = None
|
||||||
|
if context.rope_cache is not None:
|
||||||
|
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
|
||||||
|
return context.__class__(
|
||||||
|
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
|
||||||
|
cross_mask=_clone_static_tensor(context.cross_mask),
|
||||||
|
self_mask=_clone_static_tensor(context.self_mask),
|
||||||
|
valid_action=_clone_static_tensor(context.valid_action),
|
||||||
|
rope_cache=rope_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
|
||||||
|
return tuple(
|
||||||
|
step.__class__(
|
||||||
|
conditioning=_clone_static_tensor(step.conditioning),
|
||||||
|
block_modulations=tuple(
|
||||||
|
tuple(_clone_static_tensor(t) for t in block_modulation)
|
||||||
|
for block_modulation in step.block_modulations
|
||||||
|
),
|
||||||
|
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
|
||||||
|
)
|
||||||
|
for step in modulations
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
|
||||||
|
return _ActionFlowInputs(
|
||||||
|
trajectory=_clone_static_tensor(inputs.trajectory),
|
||||||
|
context=_clone_static_context(inputs.context),
|
||||||
|
modulations=_clone_static_modulations(inputs.modulations),
|
||||||
|
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_context_(dst: Any, src: Any) -> None:
|
||||||
|
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
|
||||||
|
dst_k.copy_(src_k)
|
||||||
|
dst_v.copy_(src_v)
|
||||||
|
if src.cross_mask is not None:
|
||||||
|
dst.cross_mask.copy_(src.cross_mask)
|
||||||
|
if src.self_mask is not None:
|
||||||
|
dst.self_mask.copy_(src.self_mask)
|
||||||
|
if src.valid_action is not None:
|
||||||
|
dst.valid_action.copy_(src.valid_action)
|
||||||
|
if src.rope_cache is not None:
|
||||||
|
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
|
||||||
|
dst_tensor.copy_(src_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
|
||||||
|
dst.trajectory.copy_(src.trajectory)
|
||||||
|
_copy_context_(dst.context, src.context)
|
||||||
|
if src.action_dim_is_pad is not None:
|
||||||
|
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rotary_pos_emb(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
unsqueeze_dim: int = 1,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_cuda_graph(
|
||||||
|
fn,
|
||||||
|
device: torch.device,
|
||||||
|
*,
|
||||||
|
after_warmup=None,
|
||||||
|
) -> tuple[torch.cuda.CUDAGraph, Any]:
|
||||||
|
warmup_stream = torch.cuda.Stream(device=device)
|
||||||
|
warmup_stream.wait_stream(torch.cuda.current_stream(device))
|
||||||
|
with torch.cuda.stream(warmup_stream):
|
||||||
|
fn()
|
||||||
|
torch.cuda.current_stream(device).wait_stream(warmup_stream)
|
||||||
|
if after_warmup is not None:
|
||||||
|
after_warmup()
|
||||||
|
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph):
|
||||||
|
output = fn()
|
||||||
|
return graph, output
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,431 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
"""
|
||||||
|
Processor class for MolmoAct2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers.image_utils import ImageInput
|
||||||
|
from transformers.video_utils import VideoInput
|
||||||
|
from transformers.processing_utils import (
|
||||||
|
Unpack,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
)
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
|
||||||
|
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
|
||||||
|
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
|
||||||
|
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
|
||||||
|
IM_START_TOKEN = f"<im_start>"
|
||||||
|
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
|
||||||
|
FRAME_START_TOKEN = f"<frame_start>"
|
||||||
|
IM_END_TOKEN = f"<im_end>"
|
||||||
|
FRAME_END_TOKEN = f"<frame_end>"
|
||||||
|
IM_COL_TOKEN = f"<im_col>"
|
||||||
|
IMAGE_PROMPT = "<|image|>"
|
||||||
|
VIDEO_PROMPT = "<|video|>"
|
||||||
|
|
||||||
|
IMAGE_TOKENS = [
|
||||||
|
IMAGE_PATCH_TOKEN,
|
||||||
|
IM_COL_TOKEN,
|
||||||
|
IM_START_TOKEN,
|
||||||
|
LOW_RES_IMAGE_START_TOKEN,
|
||||||
|
FRAME_START_TOKEN,
|
||||||
|
IM_END_TOKEN,
|
||||||
|
FRAME_END_TOKEN,
|
||||||
|
IMAGE_LOW_RES_TOKEN,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
"""MolmoAct2 processor kwargs"""
|
||||||
|
|
||||||
|
images_kwargs: MolmoAct2ImagesKwargs
|
||||||
|
videos_kwargs: MolmoAct2VideoProcessorKwargs
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": True,
|
||||||
|
},
|
||||||
|
"videos_kwargs": {"return_metadata": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MolmoAct2Processor(ProcessorMixin):
|
||||||
|
attributes = ["image_processor", "video_processor", "tokenizer"]
|
||||||
|
optional_attributes = [
|
||||||
|
"chat_template",
|
||||||
|
"time_mode",
|
||||||
|
"image_use_col_tokens",
|
||||||
|
"use_single_crop_col_tokens",
|
||||||
|
"use_single_crop_start_token",
|
||||||
|
"video_use_col_tokens",
|
||||||
|
"use_frame_special_tokens",
|
||||||
|
]
|
||||||
|
image_processor_class = "AutoImageProcessor"
|
||||||
|
video_processor_class = "AutoVideoProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_processor: MolmoAct2ImageProcessor = None,
|
||||||
|
video_processor: MolmoAct2VideoProcessor = None,
|
||||||
|
tokenizer: AutoTokenizer = None,
|
||||||
|
chat_template: str | None = None,
|
||||||
|
image_use_col_tokens: bool | None = True,
|
||||||
|
use_single_crop_col_tokens: bool | None = None,
|
||||||
|
use_single_crop_start_token: bool | None = True,
|
||||||
|
video_use_col_tokens: bool | None = False,
|
||||||
|
use_frame_special_tokens: bool | None = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
image_processor,
|
||||||
|
video_processor,
|
||||||
|
tokenizer,
|
||||||
|
chat_template=chat_template,
|
||||||
|
)
|
||||||
|
self.image_use_col_tokens = image_use_col_tokens
|
||||||
|
self.use_single_crop_col_tokens = use_single_crop_col_tokens
|
||||||
|
self.use_single_crop_start_token = use_single_crop_start_token
|
||||||
|
self.video_use_col_tokens = video_use_col_tokens
|
||||||
|
self.use_frame_special_tokens = use_frame_special_tokens
|
||||||
|
|
||||||
|
self.image_placeholder_token = IMAGE_PROMPT
|
||||||
|
self.video_placeholder_token = VIDEO_PROMPT
|
||||||
|
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
|
||||||
|
|
||||||
|
def get_image_tokens(self, image_grid: np.ndarray):
|
||||||
|
resized_h, resized_w, height, width = image_grid
|
||||||
|
if int(height) == 0 or int(width) == 0:
|
||||||
|
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||||
|
use_single_crop_col_tokens = (
|
||||||
|
self.image_use_col_tokens
|
||||||
|
if self.use_single_crop_col_tokens is None
|
||||||
|
else self.use_single_crop_col_tokens
|
||||||
|
)
|
||||||
|
if use_single_crop_col_tokens:
|
||||||
|
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||||
|
joint = [
|
||||||
|
[IM_START_TOKEN],
|
||||||
|
np.tile(per_row, [resized_h]),
|
||||||
|
[IM_END_TOKEN],
|
||||||
|
]
|
||||||
|
return np.concatenate(joint)
|
||||||
|
per_row = np.full(width, IMAGE_PATCH_TOKEN)
|
||||||
|
if self.image_use_col_tokens:
|
||||||
|
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||||
|
joint = [
|
||||||
|
[IM_START_TOKEN],
|
||||||
|
np.tile(per_row, [height]),
|
||||||
|
[IM_END_TOKEN],
|
||||||
|
]
|
||||||
|
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
|
||||||
|
use_single_crop_col_tokens = (
|
||||||
|
self.image_use_col_tokens
|
||||||
|
if self.use_single_crop_col_tokens is None
|
||||||
|
else self.use_single_crop_col_tokens
|
||||||
|
)
|
||||||
|
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
|
||||||
|
if use_single_crop_col_tokens:
|
||||||
|
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||||
|
joint = [
|
||||||
|
[image_start_token],
|
||||||
|
np.tile(per_row, [resized_h]),
|
||||||
|
[IM_END_TOKEN],
|
||||||
|
] + joint
|
||||||
|
|
||||||
|
return np.concatenate(joint)
|
||||||
|
|
||||||
|
def get_video_string(
|
||||||
|
self,
|
||||||
|
video_grid: np.ndarray,
|
||||||
|
timestamps: np.ndarray,
|
||||||
|
):
|
||||||
|
if self.use_frame_special_tokens:
|
||||||
|
start_token_id = FRAME_START_TOKEN
|
||||||
|
end_token_id = FRAME_END_TOKEN
|
||||||
|
else:
|
||||||
|
start_token_id = IM_START_TOKEN
|
||||||
|
end_token_id = IM_END_TOKEN
|
||||||
|
|
||||||
|
num_frames, h, w = video_grid
|
||||||
|
video_string: str = ""
|
||||||
|
for frame_idx, frame_time in enumerate(timestamps):
|
||||||
|
# `per-frame-compact` time mode
|
||||||
|
prev_space = " " if frame_idx > 0 else ""
|
||||||
|
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
|
||||||
|
|
||||||
|
video_string += frame_prefix
|
||||||
|
per_row = np.full(w, IMAGE_PATCH_TOKEN)
|
||||||
|
if self.video_use_col_tokens:
|
||||||
|
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
|
||||||
|
extra_tokens = np.tile(per_row, [h])
|
||||||
|
video_tokens = [
|
||||||
|
[start_token_id],
|
||||||
|
extra_tokens,
|
||||||
|
[end_token_id],
|
||||||
|
]
|
||||||
|
video_string += "".join(np.concatenate(video_tokens, 0))
|
||||||
|
|
||||||
|
return video_string
|
||||||
|
|
||||||
|
def insert_bos(
|
||||||
|
self,
|
||||||
|
input_ids: np.ndarray,
|
||||||
|
attention_mask: np.ndarray,
|
||||||
|
bos_token_id: int,
|
||||||
|
pad_token_id: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_ids: [B, S] array with left padding
|
||||||
|
attention_mask: [B, S] array (0 for pad, 1 for valid)
|
||||||
|
bos_token_id: int
|
||||||
|
pad_token_id: int
|
||||||
|
Returns:
|
||||||
|
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
|
||||||
|
attention_mask_out: same shape as input_ids_out
|
||||||
|
"""
|
||||||
|
|
||||||
|
need_to_expand = len(input_ids.shape) == 1
|
||||||
|
if need_to_expand:
|
||||||
|
input_ids = input_ids[None, :]
|
||||||
|
attention_mask = attention_mask[None, :]
|
||||||
|
|
||||||
|
B, S = input_ids.shape
|
||||||
|
|
||||||
|
# Handle zero-length sequence
|
||||||
|
if S == 0:
|
||||||
|
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
|
||||||
|
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
|
||||||
|
if need_to_expand:
|
||||||
|
new_input_ids = new_input_ids[0]
|
||||||
|
new_attention_mask = new_attention_mask[0]
|
||||||
|
return new_input_ids, new_attention_mask
|
||||||
|
|
||||||
|
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
|
||||||
|
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
|
||||||
|
|
||||||
|
if bos_already_present:
|
||||||
|
if need_to_expand:
|
||||||
|
input_ids = input_ids[0]
|
||||||
|
attention_mask = attention_mask[0]
|
||||||
|
return input_ids, attention_mask
|
||||||
|
else:
|
||||||
|
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
|
||||||
|
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
|
||||||
|
|
||||||
|
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
|
||||||
|
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
|
||||||
|
tgt_idx = src_idx + 1 # shit right
|
||||||
|
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
|
||||||
|
|
||||||
|
# flatten valid_positions
|
||||||
|
flat_vals = input_ids[valid_mask]
|
||||||
|
flat_batch = batch_idx[valid_mask]
|
||||||
|
flat_tgt = tgt_idx[valid_mask]
|
||||||
|
|
||||||
|
new_input_ids[flat_batch, flat_tgt] = flat_vals
|
||||||
|
new_attention_mask[flat_batch, flat_tgt] = 1
|
||||||
|
|
||||||
|
insert_pos = first_valid_index
|
||||||
|
new_input_ids[np.arange(B), insert_pos] = bos_token_id
|
||||||
|
new_attention_mask[np.arange(B), insert_pos] = 1
|
||||||
|
|
||||||
|
if need_to_expand:
|
||||||
|
new_input_ids = new_input_ids[0]
|
||||||
|
new_attention_mask = new_attention_mask[0]
|
||||||
|
|
||||||
|
return new_input_ids, new_attention_mask
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||||
|
images: ImageInput = None,
|
||||||
|
videos: VideoInput = None,
|
||||||
|
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (`str`, `list[str]`, `list[list[str]]`):
|
||||||
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. Both channels-first and channels-last formats are supported.
|
||||||
|
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
|
||||||
|
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
|
||||||
|
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
|
||||||
|
- `"timestamps"`: `np.ndarray` of shape (T,)
|
||||||
|
- `"sampled_fps"`: `float` (optional)
|
||||||
|
- `"sampling_augmentation"`: `str` (optional)
|
||||||
|
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||||
|
If set, will return tensors of a particular framework. Acceptable values are:
|
||||||
|
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||||
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`BatchFeature`: A [`BatchFeature`] with the following fields:
|
||||||
|
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||||
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||||
|
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
|
||||||
|
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||||
|
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
|
||||||
|
Returned when `images` is not `None`.
|
||||||
|
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
|
||||||
|
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
|
||||||
|
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||||
|
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
|
||||||
|
Returned when `videos` is not `None`.
|
||||||
|
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
MolmoAct2ProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
|
image_grids = image_inputs["image_grids"]
|
||||||
|
else:
|
||||||
|
image_inputs = {}
|
||||||
|
image_grids = None
|
||||||
|
|
||||||
|
if videos is not None:
|
||||||
|
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||||
|
video_grids = videos_inputs["video_grids"]
|
||||||
|
# If user has not requested video metadata, pop it
|
||||||
|
if "return_metadata" not in kwargs:
|
||||||
|
video_metadata = videos_inputs.pop("video_metadata")
|
||||||
|
else:
|
||||||
|
video_metadata = videos_inputs["video_metadata"]
|
||||||
|
else:
|
||||||
|
videos_inputs = {}
|
||||||
|
video_grids = None
|
||||||
|
|
||||||
|
if not isinstance(text, list):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
text = text.copy() # below lines change text in-place
|
||||||
|
|
||||||
|
if image_grids is not None:
|
||||||
|
index = 0
|
||||||
|
for i in range(len(text)):
|
||||||
|
num_images = text[i].count(self.image_placeholder_token)
|
||||||
|
image_grids_i = image_grids[index : index + num_images]
|
||||||
|
for image_grid in image_grids_i:
|
||||||
|
image_tokens = self.get_image_tokens(image_grid)
|
||||||
|
image_string = "".join(image_tokens)
|
||||||
|
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
|
||||||
|
index += num_images
|
||||||
|
|
||||||
|
if video_grids is not None:
|
||||||
|
index = 0
|
||||||
|
for i in range(len(text)):
|
||||||
|
num_videos = text[i].count(self.video_placeholder_token)
|
||||||
|
assert num_videos in {0, 1}, "At most one video is supported for now"
|
||||||
|
video_grids_i = video_grids[index : index + num_videos]
|
||||||
|
metadata_i = video_metadata[index : index + num_videos]
|
||||||
|
for video_grid, metadata in zip(video_grids_i, metadata_i):
|
||||||
|
video_string = self.get_video_string(
|
||||||
|
video_grid,
|
||||||
|
metadata.timestamps,
|
||||||
|
)
|
||||||
|
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
|
||||||
|
index += num_videos
|
||||||
|
|
||||||
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
|
|
||||||
|
input_ids = text_inputs["input_ids"]
|
||||||
|
attention_mask = text_inputs["attention_mask"]
|
||||||
|
|
||||||
|
input_ids = np.array(input_ids)
|
||||||
|
attention_mask = np.array(attention_mask)
|
||||||
|
|
||||||
|
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
||||||
|
input_ids, attention_mask = self.insert_bos(
|
||||||
|
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
|
||||||
|
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
|
||||||
|
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
||||||
|
|
||||||
|
text_inputs["input_ids"] = input_ids.tolist()
|
||||||
|
text_inputs["attention_mask"] = attention_mask.tolist()
|
||||||
|
|
||||||
|
return BatchFeature(
|
||||||
|
data={**text_inputs, **image_inputs, **videos_inputs},
|
||||||
|
tensor_type=return_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_process_image_text_to_text(
|
||||||
|
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Post-process the output of the model to decode the text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||||
|
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||||
|
or `(sequence_length,)`.
|
||||||
|
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
|
||||||
|
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
|
||||||
|
**kwargs:
|
||||||
|
Additional arguments to be passed to the tokenizer's `batch_decode method`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`list[str]`: The decoded text.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(
|
||||||
|
generated_outputs,
|
||||||
|
skip_special_tokens=skip_special_tokens,
|
||||||
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MolmoAct2Processor.register_for_auto_class()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user