Compare commits

...

24 Commits

Author SHA1 Message Date
Khalil Meftah acd31c7de2 fix(eval): use FeatureType enum comparison instead of string value 2026-06-16 15:22:50 +02:00
Khalil Meftah 240393d238 feat(eval): record eval rollouts as raw LeRobot datasets
- Record raw env observations inline during rollout(), before
preprocess_observation() transforms them. Uses LeRobotDataset.create()
with add_frame()/save_episode().

- Supports vectorized envs: each env in the batch records independently,
with save_episode() called per env on termination. Each task gets its
own dataset under output_dir/recordings/{task_group}_{task_id}/.

Enabled via --eval.recording=true; disabled by default.
2026-06-15 16:12:25 +02:00
Altman 8515d456be fix(datasets): avoid uint8 overflow in image stats (#3697)
* fix(datasets): avoid uint8 overflow in image stats

* fix(datasets): promote stats batches dynamically
2026-06-13 12:09:43 +02:00
Mahbod 30790de178 feat(edit-dataset): add concatenate_videos opt-out to merge (#3663)
* feat(edit-dataset): add `concatenate_videos` opt-out to merge

When merging datasets, source mp4s are concatenated into shards capped at
`video_files_size_in_mb` (default 200 MB). This is great for dataloader
throughput but destroys per-episode (or per-source) video boundaries,
which is undesirable when you want to inspect, ship, or reuse the
individual mp4s.

Add a `concatenate_videos: bool = True` knob plumbed through
`MergeConfig` → `merge_datasets` → `aggregate_datasets` → `aggregate_videos`.
When False, each source mp4 is copied 1:1 to its own destination mp4 with
no re-muxing, so the merge preserves source video boundaries.

Usage:

    lerobot-edit-dataset \
        --new_repo_id user/merged \
        --operation.type=merge \
        --operation.repo_ids "['user/a', 'user/b']" \
        --operation.concatenate_videos=false

Defaults are unchanged; the dataloader path is unaffected because the
`episodes.parquet` `from_timestamp`/`to_timestamp` index keeps working
regardless of whether each mp4 holds one or many episodes.

* feat(edit-dataset): extend concatenate opt-out to data files

Following review, add a concatenate_data flag mirroring concatenate_videos,
threaded through MergeConfig, merge_datasets, aggregate_datasets, aggregate_data
and append_or_create_parquet_file. Metadata index files still always concatenate.

Also trim the verbose docstrings and comments since the names are
self-explanatory, and extend the existing merge test to cover data files.
2026-06-12 20:05:04 +02:00
Pepijn cec8ee0be6 feat: language annotation pipeline (#3471)
Steerable annotation pipeline (lerobot-annotate) that populates the language_persistent and language_events columns introduced in PR 1 (#3467) directly into data/chunk-*/file-*.parquet.

This is PR 2 of the three-PR plan:

PR 1 (Add extensive language support #3467): schema + DSL + rendering, base of this PR
PR 2 (this PR): annotation pipeline writing into PR 1's columns
PR 3: model with language prediction and runtime
A VLM (Qwen-VL family, served on vLLM) watches each episode's video and emits grounded language annotations: subtasks, plans, memory, task rephrasings, interjections + speech, and per-camera VQA. The pipeline is built for production annotation at scale — single-camera grounding, embedded-frame inputs, a describe-then-segment grounding flow, and a deterministic full-episode coverage guarantee — informed by Scale's dense-captioning findings (representation > sampling, rules > reasoning, model capacity is the biggest lever, two-pass systems compound errors)
2026-06-12 15:12:33 +02:00
Nikodem Bartnik 02b315ab6a Docs/model card improvements (#3634)
* update policy deployment instruction with rollout

* add port and fix formatting

* add more base models to generate model card

* updated and extended model descriptions

* fix bug

* improved and extended structure

* exclude the templates from config

* add images and visualize dataset button

* add all policies we have docs for

* remove policies without the docs

* new fields, improved examples
2026-06-12 13:26:52 +02:00
Pepijn 234c768dfb feat(datasets): deterministic, resumable shuffling for EpisodeAwareSampler (#3769)
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync

In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* fix(train): seed sampler generator and gate dataset download per node

- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume

Add a sampler that never materializes frame indices: it stores only
per-episode boundaries (numpy, a few bytes per episode) and maps logical
positions to frame indices on the fly with searchsorted. Shuffling uses a
seeded Feistel permutation over [0, num_frames) (cycle-walking to the
exact domain), so the data order is a pure function of (seed, epoch):

- no RNG state to synchronize across distributed ranks,
- constant memory and zero epoch-boundary cost at any dataset size,
- O(1) seek to any position, enabling sample-exact resume.

Opt in with --deterministic_sampler=true. On resume, lerobot-train maps
the checkpointed step back to (epoch, start_index) via
compute_sampler_state and continues at the exact sample where the run
left off (up to accelerate's even_batches padding at epoch boundaries).
The shuffle is pseudo-random rather than a true uniform permutation, the
standard trade-off in large-scale training loaders.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* refactor(datasets): fold deterministic mode into EpisodeAwareSampler

Instead of a parallel DeterministicEpisodeAwareSampler class, extend the
existing EpisodeAwareSampler with a deterministic=True mode (seeded
Feistel permutation, epoch auto-advance, state_dict/load_state_dict).

The default mode is behavior-identical: same torch.randperm consumption
and the same generator contract accelerate synchronizes; the O(N) Python
index list is replaced by O(num_episodes) boundary arrays in both modes,
with `indices` kept as a back-compat property. Passing a generator
together with deterministic=True is rejected, and the state/seek methods
raise outside deterministic mode.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* feat(train): enable deterministic_sampler by default

Deterministic data order (sample-exact resume, no cross-rank RNG sync,
O(1) sampler memory) is now the default for map-style training; set
deterministic_sampler=false to restore the legacy RNG-based shuffle.
Streaming datasets ignore the flag (the sampler path only applies to
map-style datasets), replacing the previous hard validation error so
streaming configs keep working with the new default.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments

deterministic=True is now the class default as well as the training
default; the legacy RNG path requires an explicit deterministic=False
(the train script's non-deterministic branch passes it). Docstrings and
inline comments slimmed down across the changed files.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc

list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and
preallocates that many slots before iterating, OOMing even though the resumed epoch only
yields 3 frames. Collect through the iterator (no length hint) so the test exercises the
real O(1) seek/drain instead of CPython's list growth heuristic.

* fix(datasets): guard Feistel cycle-walking loop against non-convergence

Replace the unbounded while True in EpisodeAwareSampler._permute with a
bounded for loop capped at _MAX_CYCLE_WALK_STEPS (100) and raise
RuntimeError if the cycle-walk fails to land in [0, num_frames). The
loop is expected to converge in <4 steps on the chosen power-of-two
domain, so the bound is a safety net that should never trip in practice
but prevents a pathological infinite loop.

https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22

* fix(datasets): make deterministic-sampler resume robust to world-size changes

compute_sampler_state mapped a checkpointed step back to (epoch, start_index)
using the *current* num_processes, but the number of sampler positions a step
consumes scales with the world size that produced it. Resuming on a different
GPU count therefore landed on the wrong epoch/offset, silently re-seeing or
skipping data.

Record num_processes in training_step.json at checkpoint time and feed the
checkpoint's value into compute_sampler_state on resume, so the data order
resumes at the right position regardless of the new world size. Warn when the
world size changed (the global offset is correct, but per-rank sample-exactness
needs the same topology). Old checkpoints without the field fall back to the
current world size.

Also document compute_sampler_state's assumptions explicitly: num_processes /
batch_size must match the checkpointing run, and accelerate's even_batches=True
padding is mirrored by the ceil(... / num_processes) term.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* style: apply ruff-format to lerobot_train.py

Collapse the compute_sampler_state(...) call onto one line so the
ruff-format pre-commit hook passes (fixes the failing CI check).

Co-authored-by: Cursor <cursoragent@cursor.com>

* refactor(datasets): use seeded torch.randperm instead of Feistel in EpisodeAwareSampler

Drop the Feistel permutation (and its SplitMix64 hash / cycle-walking) in favor of a
torch.randperm seeded from (seed, epoch). The deterministic mode keeps its key properties
- data order is a pure function of (seed, epoch), so it reproduces on every rank with no
  global-RNG synchronization, and
- state_dict / load_state_dict still resume sample-exactly, now by regenerating the epoch's
  permutation and slicing from the saved offset.

Construction stays O(num_episodes) (only episode boundaries are stored, never a per-frame
index list). The trade-off vs Feistel: the per-epoch shuffle is again O(num_frames) memory
(the randperm tensor) and no longer O(1)-seekable, in exchange for ~30 fewer LOC and a truly
uniform shuffle. Tests updated: the trillion-frame O(1) test is replaced with a
boundary-storage check and a scale resume-exactness test.

Co-authored-by: Cursor <cursoragent@cursor.com>

* refactor(datasets): make EpisodeAwareSampler always deterministic

With Feistel gone, deterministic and legacy modes were both just torch.randperm and the
deterministic path strictly dominated (reproducible across ranks via the (seed, epoch) seed,
no accelerate generator sync, resumable). Collapse to a single path and drop the redundant
flag:

- remove the `deterministic` and `generator` constructor args, `_iter_default`, and
  `_require_deterministic`; `set_epoch` / `state_dict` / `load_state_dict` are now unconditional
- remove the `deterministic_sampler` train config field and the legacy generator branch in
  lerobot_train.py (non-streaming map datasets always use the sampler)
- drop the now-obsolete generator/legacy tests

Note: removes the `generator` kwarg from EpisodeAwareSampler (back-compat break vs main); the
order is now a pure function of (seed, epoch), so no cross-rank RNG sync is needed.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(datasets): address sampler review (batch_size resume guard + docs)

- Record batch_size in training_step.json alongside num_processes and feed
  the checkpoint's value into compute_sampler_state on resume; warn when it
  differs (per-rank sample-exactness needs the same batch size).
- Document the set_epoch vs __iter__ auto-advance coupling on EpisodeAwareSampler
  (callers should rely on exactly one mechanism per run).
- Note the broadened (reproducibility-breaking) sampler guard and the no-generator
  distributed sharding correctness in lerobot_train.py.
- Add load_training_batch_size + parallel tests.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(train): download dataset once on the global main process

Gate the training dataset download on the global is_main_process (download once to the
shared dataset root, barrier, then every other rank reads the already-populated copy)
instead of per-node is_local_main_process. LeRobotDataset skips its snapshot_download
when try_load() succeeds, so no rank re-downloads. Assumes the dataset root / HF cache is
on storage shared across nodes.

Co-authored-by: Cursor <cursoragent@cursor.com>

* chore(datasets): trim sampler comment and drop duplicate tests

Remove the verbose dataloader-guard comment and the two EpisodeAwareSampler tests
that duplicated existing validation/warning coverage (no coverage loss).

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 11:47:16 +02:00
Caroline Pascal 0e9bd9e6fb feat(trim): adding optional trimming option in reencode_video (#3779)
* feat(trim): adding optional trimming option in reencode_video

* tests(trim): add triming test

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-12 11:29:26 +02:00
Steven Palma 87242cfced chore(dependecies): relax grpc-related bounds (#3777)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-06-11 19:13:14 +02:00
Steven Palma 1edc83a0ef feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup (#3773)
* feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup

* chore: address feedback
2026-06-11 19:07:28 +02:00
Steven Palma 6fbcf67249 chore: update readme (#3774)
* chore: update readme

* chore: update authors in project readme
2026-06-11 18:17:26 +02:00
Pepijn 41166b39fb fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768)
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync

In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

* fix(train): seed sampler generator and gate dataset download per node

- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.
2026-06-11 11:07:42 +02:00
Steven Palma 79c6821407 chore(dependecies): update mujoco transitives (#3756) 2026-06-10 12:58:55 +02:00
Steven Palma 507083249f Revert "fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751)" (#3754)
This reverts commit bd22407d93.
2026-06-10 10:38:42 +02:00
Caroline Pascal bd22407d93 fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751)
* fix(pyproject): adding ceiling bound on mujoco (<3.9.0)

* chore(uv.lock): updating uv.lock

* fix(linux): adding missing linux dependencies

* chore(uv.lock): updating uv.lock
2026-06-09 23:31:43 +02:00
Adil Zouitine 49755a3d9e feat(processor): Add in-memory processor pipeline serialization (#3732)
* feat(processor): add in-memory pipeline serialization

Expose processor pipeline config and tensor state without requiring temporary files, so processors can be transported, compared, or hashed directly in memory.

* feat(processor): enhance DataProcessorPipeline with registry support

- Added a new RegisteredLazyTensorStateStep for registry-based serialization tests.
- Improved state filename handling in _get_state_filename method.
- Refactored validation logic in _validate_loaded_config to simplify parameter types.
- Updated tests to verify registry step functionality and ensure correct state loading.

* refactor(processor): update state handling in DataProcessorPipeline

- Introduced a new static method _get_state_key to derive in-memory state keys from serialized filenames.
- Updated state_dict and load_state_dict methods to use suffixless state keys instead of filenames.
- Adjusted related tests to reflect changes in state key handling, ensuring consistency in state management

* fix(processor): update loaded_config argument description in DataProcessorPipeline

- Clarified the documentation for the loaded_config parameter to indicate that it may be a non-dictionary value, enhancing understanding for future developers.
2026-06-08 11:27:24 +02:00
Maxime Ellerbach 09808183ca feat(rollout): adding episodic strategy (#3717)
* feat(rollout): adding legacy strategy

* adding legacy to existing tests

* updating docs and docstring

* changing misleading docstring

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding extra guard like dagged with try except finally

* Potential fix for pull request finding

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding reset to initial position

* moving smooth teleop handover to control_utils and adding this behavior to legacy strategy

* reducing duration of the handover

* * renaming to episodic
* changing semantics of the docstring
* fixing leader - follower handover disable torque
* adding optionnal config to disable handover

* wiring the smooth_leader_follower_handover config

* renaming config smooth_leader_to_follower_handover

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-06-06 00:32:38 +02:00
Maxime Ellerbach 2e9cd87bbd feat(policies): add VLA-JEPA (#3568)
* first commit

* feat(policies): add VLA-JEPA

* feat(policies): add VLA-JEPA

* support vla_jepa

* (feat)policies: add VLA-JEPA

* linting

* adding deps to pyproject.toml

* updating uv lock

* adding guards to avoid needing transformers and diffusers for type checking and basic tests

* fixing action and state dim

* fix warnings with qwen processor kwargs

* fixing wm_loss not propagating

* adjusting obs steps, tublets size to match original implementation

* some more fixes to be closer to the original implem

* adding more tests to ensure good coverage

* align VLA-JEPA architecture with original checkpoint

- Remove stale `action_num_heads` / `action_attention_head_dim` config fields;
  DiT head dimensions are now always derived from the preset (DiT-B/L/test).
- Add `num_target_vision_tokens` and `action_max_seq_len` config fields required
  by the action head's future-token embedding and positional embedding tables.
- Fix default `qwen_model_name` to 2B (matches all released checkpoints).
- Rename `ActionEncoder` attrs w1/w2/w3 → layer1/layer2/layer3 to match
  checkpoint key names; replace `nn.Sequential` decoder/state-encoder with
  `_MLP2` (layer1/layer2 naming).
- Fix `VLAJEPAActionHead` to size ActionEncoder and StateEncoder at `inner_dim`
  (DiT input width) rather than `action_hidden_size` (DiT output width).
- Rename `DiT.blocks` → `transformer_blocks` and `attn` → `attn1` to match
  checkpoint; add alternating cross/self attention (even blocks cross-attend to
  Qwen context, odd blocks self-attend).
- Add `DiT-test` preset for unit tests.
- Rewrite `ActionConditionedVideoPredictor` with explicit ViT-style blocks
  (`_PredictorBlock` with fused qkv) to match checkpoint structure; rename
  `encoder`/`norm`/`proj` → `predictor_blocks`/`predictor_norm`/`predictor_proj`.

* propagate action_is_pad masking through VLA-JEPA policy pipeline

Pass the `action_is_pad` tensor from the batch through to the action head
so padded timesteps are excluded from the flow-matching loss.

* update VLA-JEPA tests for arch changes and action_is_pad

- Switch conftest to use `action_model_type="DiT-test"` now that
  `action_num_heads` / `action_attention_head_dim` have been removed.
- Add action_head tests covering fully-padded loss (zero) and equivalence
  of action_is_pad=None vs all-zeros mask.
- Remove obsolete `test_native_to_lerobot_wm_only` test.

* add VLA-JEPA documentation

Covers architecture overview, pretrained checkpoints, config reference,
training/eval commands for LIBERO-10, and guidance on fine-tuning for
single-camera datasets.

* add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated)

* make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model
- added tests for weight loading

* trying out to re-init the action head to avoid pretraining dimension mismatch

* allow different state dim and action dim

* removing missleading future_action_window_size to just use chunk_size

* lots of changes to make existing weights work, need to massively refactor the pre and post processing

* refactoring into using pre and post processor

* pre-commit cleanup

* fixing doc defaults args

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adressing dtype zeros issue

* adding guard for diffusers

* fixing training and exal examples

* trying to close success rate gap

* fix qwen norm layer output libero eval is now as expected

* adding instructions for different embodiement + fixing some tests

* smol fix to avoid having default CPU device when training

* fixing misconception about multiview / singleview handling

* removing conversion script

* adding licences

* adding .mdx docs and shortening polivy_vla_jepa_README.md

* removing useless pre-processor

* cleanup

* removing swish in favor of silu

* adding configuration gripper index and threshold

* fixing simlink

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: ginwind <ginwind@mail.ustc.edu.cn>
2026-06-04 19:22:51 +02:00
Jaimin d1b1c5c8cf docs: fix broken dataset script paths (datasets/v30 -> scripts) (#3695)
The docs pointed at src/lerobot/datasets/v30/, which does not exist.
Both scripts actually live in src/lerobot/scripts/:

- convert_dataset_v21_to_v30.py
- augment_dataset_quantile_stats.py

Updated the four references (one python -m module path and three
file-path invocations) to the correct location, matching each
script's own usage docstring.
2026-06-03 14:48:19 +02:00
Nikodem Bartnik 741c2d0a39 Docs/add lelab (#3707)
* first text draft (no images)

* simplified docs

* fix formatting

* add youtube video

* add a tip about compatibility

* fix broken link
2026-06-03 14:22:05 +02:00
Haoming Song 19fe315971 fix(train): enable relative action overrides for pretrained processors (#3711)
* fix(train): enable relative action overrides for pretrained processors
Keep pretrained processor pipelines when use_relative_actions is enabled and
apply relative/absolute action processor settings through overrides. Rename the
relative action processor registry key to relative_actions_processor.

* fix(config): reject rename_map without pretrained checkpoint

Fail fast when rename_map is set during fresh initialization, since fresh
configs derive feature names from the current dataset and no rename is applied.

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-03 11:46:35 +02:00
Khalil Meftah 906b585826 fix(datasets): default private to None in push_to_hub to respect Hub org visibility settings (#3713) 2026-06-02 19:25:13 +02:00
Khalil Meftah b8ad81bf39 feat(rewards): add ROBOMETER reward model (#3627)
* feat/add ROBOMETER reward model

* feat(rewards): add Robometer offline progress labeling script

* fix(rewards/robometer): add missing input keys mm_token_type_ids

* chore(rewards/robometer): default to lerobot/Robometer-4b model

* doc(rewards/robometer): update citation and original github link

* feat(rewards/robometer): add image key argument to compute Robometer progress
2026-05-29 21:45:39 +02:00
Haoquan Fang 24017e960c Add MolmoAct2 policy (#3604)
* add molmoact2 policy

* add apache headers to molmoact2 files

* simplify molmoact2 package imports

* align molmoact2 feature validation with eo pattern

* remove molmoact2 processor override from factory

* guard molmoact2 transformers imports

* guard molmoact2 processor transformers import

* add scipy dependency to molmoact2 extra

* use a single molmoact2 action queue

* move molmoact2 config logic into config

* fix molmoact2 hf image key resolution

* load molmoact2 without remote code

* lazy import molmoact2 scipy

* format molmoact2 files

* skip molmoact2 tests without optional deps

* fix molmoact2 pre-commit checks

* validate molmoact2 gripper range
2026-05-27 18:58:37 +02:00
131 changed files with 27686 additions and 341 deletions
+3
View File
@@ -65,6 +65,9 @@ repos:
name: Format Markdown with Prettier
types_or: [markdown, mdx]
args: [--prose-wrap=preserve]
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
exclude: ^src/lerobot/templates/.*\.md$
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
+6
View File
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
--env.episode_length=5 \
--eval.n_episodes=1 \
--eval.batch_size=1
# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
uv run python -m tests.annotations.run_e2e_smoke
+10 -7
View File
@@ -58,7 +58,7 @@ action = model.select_action(obs)
robot.send_action(action)
```
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
@@ -101,11 +101,13 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
## Citation
@@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
+10
View File
@@ -9,6 +9,8 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: lelab
title: LeLab - Lerobot GUI
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
@@ -43,6 +45,8 @@
title: Language Columns and Recipes
- local: tools
title: Tools
- local: annotation_pipeline
title: Annotation Pipeline
- local: video_encoding_parameters
title: Video encoding parameters
- local: streaming_video_encoding
@@ -59,6 +63,10 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: groot
@@ -73,6 +81,8 @@
- sections:
- local: sarm
title: SARM
- local: robometer
title: ROBOMETER
- local: topreward
title: TOPReward
title: "Reward Models"
+291
View File
@@ -0,0 +1,291 @@
# Annotation Pipeline
`lerobot-annotate` watches each episode's video with a vision-language
model (VLM) and writes natural-language annotations back into your
dataset. It fills the two language columns from the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — straight into
`data/chunk-*/file-*.parquet`.
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
memory, interjections, speech, and visual Q&A that a policy can be
trained on.
## How it fits together
```text
your dataset lerobot-annotate
(LeRobot v3.1)
┌─────────────────────────────────────────────────────┐
│ read episodes │
└──────────────────────────┬──────────────────────────┘
┌────────────────────┼────────────────────┐
▼ ▼ ▼
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
└────────────────────┼─────────────────────┘
│ each module stages raw JSONL
▼ into .annotate_staging/
┌─────────────────┐
│ validator │ ◀── checks everything
└────────┬────────┘
┌─────────────────┐
│ writer │
└────────┬────────┘
data/chunk-*/file-*.parquet
(+ meta/info.json tools)
```
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
VLM. Each module stages its output to disk, a validator checks it, and a
single writer rewrites the dataset shards in place.
## What the pipeline produces
Each module emits a few kinds of annotation ("styles"), routed to one of
the two language columns:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | --------------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
| `interjection` | `language_events` | `interjections` |
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
### How subtasks are generated
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
it uses a two-step **describe → segment** flow:
1. **Describe** — the VLM narrates only what it actually sees in the
chosen camera (no guessing about the task).
2. **Segment** — that description is fed back in, and the VLM splits the
episode into consecutive atomic subtasks.
Both passes see the episode as **timestamped contact sheets** — frames
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
grids with each frame's time burned into its corner, so the VLM cites
exact boundary times directly. This is far cheaper in vision tokens than
one image per frame, so the sampling can stay dense; episodes longer than
`max_frames_per_prompt` are split into windows at the same density and
merged. Both prompts also carry a causal **event-boundary** definition (a
new event starts when an object becomes held / is released / reaches a new
location / a lid changes state / contents move) to sharpen where cuts land.
The resulting spans are then stitched into a gap-free, full-episode
cover, so **every frame has exactly one active subtask**. See
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
for the production settings (single camera, timestamped contact sheets,
auto-windowed subtask generation).
### Tools
The writer does **not** add a `tools` column to the parquet. The tool
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
After every run, the pipeline makes sure the canonical `say` schema is in
that list, keeping any tools you declared beforehand.
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
pipeline preserves whatever is already there. That makes the tool visible
to the chat template, so the model can learn to _generate_ the call. The
runtime layer that actually _executes_ a generated call (the `Tool`
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
this PR — the [Tools](./tools) doc marks those pieces as
not-yet-implemented.
## Running on Hugging Face Jobs
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
The repo ships a launcher script you copy and tweak for your dataset:
```bash
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
```
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
that:
1. installs `lerobot` (from `main`) plus the annotation extras,
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
drives it over the OpenAI-compatible API,
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
with `lerobot-annotate`,
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
back to `--repo_id` in place if you leave that unset).
To use a different dataset, model, or hub repo, edit the `CMD` block in
the script. Every flag there maps directly to a `lerobot-annotate` flag
(run `lerobot-annotate --help` for the full list).
## Key options
These are the flags you'll reach for most often. Run
`lerobot-annotate --help` for everything else; the defaults are tuned for
short manipulation episodes.
### Dataset in / out
| Flag | Default | What it does |
| ----------------- | ------- | ----------------------------------------------------------------------- |
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
| `--root` | — | Annotate a local dataset directory instead. |
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
### Which modules run
Every module is on by default and can be toggled independently (set to
`false` to skip it, e.g. to iterate on one module at a time):
| Flag | Default | Turns off |
| ------------------------- | ------- | ----------------------------------- |
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
| `--interjections.enabled` | `true` | interjections + speech atoms |
| `--vqa.enabled` | `true` | the VQA pairs |
### The VLM (`--vlm.*`)
| Flag | Default | What it does |
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
| `--vlm.temperature` | `0.2` | Sampling temperature. |
### Subtasks / plan / memory (`--plan.*`)
| Flag | Default | What it does |
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
### Interjections + VQA
| Flag | Default | What it does |
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
## Contributing new modules
The pipeline is built to grow, and **contributions are very welcome** —
a brand-new module (say, trajectory traces or affordances), a new prompt
template, a smarter grounding flow, or quality fixes to the existing
`plan` / `interjections` / `vqa` modules.
Every module lives under
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
client and the keyframe cache, writes its raw output to the staging
tree, and plugs into the executor as its own phase. Got an idea? Open an
issue or PR on [the repo](https://github.com/huggingface/lerobot).
## How recipes consume the output
The annotations are meant to be read by recipes (see
[Language Columns and Recipes](./language_and_recipes)). Typically:
- low-level / high-level / memory-update branches read
`subtask` / `plan` / `memory` from `language_persistent`.
- an interjection-response branch reads `interjection` events plus the
paired speech atom (merged into one assistant turn via `tool_calls_from`)
and the matching `plan` refresh at the same timestamp.
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
`language_events`.
## Why state and events are split
Two ideas shape the design:
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
`plan`, `memory`) apply to the whole episode and answer "what's true
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
the one frame whose timestamp matches. Timestamps are copied straight
from the source parquet — never recomputed in floating point.
2. **One VLM pass.** All three modules share a single VLM client (the
OpenAI-compatible client talking to the job's vLLM server), so you pay
for one model load per dataset, not three.
## Re-running a single module
Each module stages its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
prompt iteration cheap: re-running one module overwrites only its own
JSONL, then the writer recomposes the final parquet. Disable modules you
don't want with `--plan.enabled=false` (and likewise
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
## What the validator checks
Before the writer runs, `StagingValidator` confirms:
- every event row lands exactly on a real frame timestamp;
- no speech / interjection pairs are left orphaned;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (a warning, not an error);
- each VQA assistant `content` is valid JSON in one of the
bbox / keypoint / count / attribute / spatial shapes;
- every row goes to the column chosen by `column_for_style(style)`.
Any error aborts the writer. Pass `--skip_validation=true` to override
while debugging.
## Where each module's ideas come from
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
for atom granularity ("pick up one piece of lettuce", "place bowl to
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
for "how, not what" detail.
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
keep only the minimal relevant information — preserve outcomes, drop
specific attributes.
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable VLA Policies
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
grounding. Pi0.7 also grounds answers across abstraction levels.
When improving a module, tweak its prompt template in
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
rewriting from scratch.
## Roughly how much it costs
Per episode, the pipeline makes about `max_steps` plan calls,
`max_interjections_per_episode` interjection calls, and
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
~50 VLM calls.
Storage stays small: `language_persistent` is at most tens of KB per
episode (parquet dictionary-encodes the one entry that repeats across
frames), and `language_events` is empty on most frames — its size scales
with the number of emissions, not `num_frames × num_emissions`.
+1
View File
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
- `episodic`: Episode-oriented policy recording with reset phases between episodes
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+38
View File
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
### Episodic (`--strategy.type=episodic`)
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
```bash
lerobot-rollout \
--strategy.type=episodic \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so100_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/my_eval_data \
--dataset.num_episodes=20 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--dataset.single_task="Pick up the red cube"
```
Teleop is optional — if omitted the robot holds its position during the reset phase.
**Keyboard controls:**
| Key | Action |
| ----------- | -------------------------------- |
| `→` (right) | End the current episode early |
| `←` (left) | Discard episode and re-record it |
| `ESC` | Stop the recording session |
| Flag | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------- |
| `--dataset.num_episodes` | Number of episodes to record |
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
---
## Inference Backends
+29
View File
@@ -0,0 +1,29 @@
# LeLab - LeRobot Guide
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
> [!TIP]
> For now LeLab is compatible only with SO-ARM101
<Youtube id="VqyKUuW9V1g" />
### Installation
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
```
uv tool install git+https://github.com/huggingface/leLab.git && lelab
```
After install, run `lelab` from your terminal anytime to start the app.
### Features
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
+1 -1
View File
@@ -275,7 +275,7 @@ A converter aggregates perepisode files into larger shards and writes episode
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
# Convert an existing v2.1 dataset hosted on the Hub:
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
```
**What it does**
+433
View File
@@ -0,0 +1,433 @@
# MolmoAct2 Policy
MolmoAct2 is the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
training, evaluation, checkpointing, and dataset interfaces for easier use with
LeRobot datasets.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## Installation Requirements
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
`gradient_checkpointing=true` and include the forward pass, backward pass,
gradient clipping, optimizer step, and optimizer state allocation. Values are
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
dataloader workers, CUDA context, and fragmentation.
Multi-GPU training through `accelerate` increases throughput and global batch
size, but this LeRobot port does not currently expose the original MolmoAct2
`fsdp_devices` model-parallel training path. The current training script has
not been tested for multi-node training.
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
The repo has been tested with Ubuntu 22.04.
## Usage
To use MolmoAct2 in a LeRobot training config, set:
```python
policy.type=molmoact2
```
## Training
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
the same LeRobot training loop, dataset transforms, checkpoint saving, and
logging. The difference is only how the initial policy weights and processor
state are loaded.
### Training With Original MolmoAct2 Weight
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
the original HF model files, then build its own policy processor from the
dataset metadata and the policy options below.
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
augmentation, and LeRobot's checkpointing/logging path.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.setup_type="single franka robotic arm in libero" \
--policy.control_mode="delta end-effector pose" \
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--policy.freeze_embedding=true \
--policy.normalize_gripper=false \
--policy.enable_knowledge_insulation=false \
--policy.push_to_hub=false \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Training With LeRobot MolmoAct2 Weight
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
restores the saved LeRobot policy config, model weights, processor, and
normalization statistics. You can still override training-time options such as
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.path=/path/to/pretrained_model \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Common Practices
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
or a real-world dataset with less than 200 demonstrations, a global batch size of
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
cases, we intentionally keep the action expert fully trainable, which we found
to be crucial for model performance. For larger fine-tuning datasets, larger
global batch sizes and full fine-tuning are usually preferred.
### Common Policy Options
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
Use this for released MolmoAct2 weights.
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
created by LeRobot training.
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
`both`. `both` trains the flow-matching action expert and the discrete
action-token loss.
- `policy.train_action_expert_only`: trains only parameters whose names contain
`action_expert`. It requires `policy.action_mode=continuous`.
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
expert linear layers. When `policy.enable_lora_action_expert=false`, the
action expert base weights remain fully trainable while the VLM is trained
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
action expert is also adapter-tuned instead of fully fine-tuned.
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
context K/V states before the action loss. The default is `false`.
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
`10`. This LeRobot port overrides the loaded checkpoint's
`max_action_horizon` with this value.
- `policy.n_action_steps`: number of actions consumed from each predicted
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
- `policy.setup_type`: text inserted into the prompt to describe the robot and
scene, e.g. `single franka robotic arm in libero`. More examples are listed
in the `metadata_by_tag` entries of
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
- `policy.control_mode`: text inserted into the prompt to describe the action
space, e.g. `delta end-effector pose` or `absolute joint pose`.
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
processor.
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
example during training. We use `8` for fine-tuning.
- `policy.num_inference_steps`: optional override for continuous action
generation steps at inference time.
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
to reduce activation memory.
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
- `policy.normalize_gripper`: controls whether gripper dimensions are included
in state/action quantile normalization. The default is `false`.
- `policy.normalize_language`: normalizes task strings before prompt
construction. The default is `true`.
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
Released checkpoints use `policy.expected_max_action_dim=32`.
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
infer it from images, state dimension, action dimension, action horizon, and
discrete-action mode.
### Learning Rates
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
fine-tuning experiments.
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
`policy.optimizer_vit_lr=5e-6` for the vision tower,
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
`policy.enable_lora_action_expert=false`, so the action expert is still fully
fine-tuned with `policy.optimizer_action_expert_lr`. If
`policy.enable_lora_action_expert=true`, the action expert is trained through
LoRA adapters instead.
- Action-expert-only fine-tuning trains only the action expert and uses
`policy.optimizer_action_expert_lr=5e-5`.
You can override the full fine-tuning and action-expert learning rates with
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
### Dataset Quantile Statistics
MolmoAct2 defaults to quantile normalization for state and action features. If
your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
Alternatively, train MolmoAct2 with mean/std normalization:
```bash
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
```
## Evaluation
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
fixed and use `policy.per_episode_seed=true`.
**Important:** We found that `num_steps_wait=10` does not reliably let the
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
results reported here use `num_steps_wait=50`.
### Evaluation With LeRobot MolmoAct2 Weight
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
normalization statistics are restored together with the model.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
--policy.inference_action_mode=continuous \
--policy.model_dtype=bfloat16 \
--policy.use_amp=true \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
### Evaluation With Original MolmoAct2 Weight
You can evaluate a released Hugging Face checkpoint directly without first
converting it to a LeRobot checkpoint. In this case, set
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
normalization statistics, action horizon, prompt metadata, and image-key order
from the checkpoint's `norm_stats.json`.
To fully replicate the MolmoAct2 paper results with released Hugging Face
checkpoints, we recommend using the v0.5.1-pinned
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
branch. That branch matches the original evaluation settings used for the
reported numbers.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.norm_tag=libero \
--policy.inference_action_mode=continuous \
--policy.model_dtype=float32 \
--policy.use_amp=false \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_goal \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
full LIBERO suite. The same command works for other released MolmoAct2
checkpoints as long as the requested `policy.norm_tag` exists in that
checkpoint's `norm_stats.json`.
### Common Evaluation Options
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
flow-matching inference or `discrete` for action-token inference. It must be
compatible with the training-time `policy.action_mode` saved in the
checkpoint.
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
saved by LeRobot.
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
image-key order, and action horizon from the original checkpoint's
`norm_stats.json`. It is required for direct original-HF checkpoint
evaluation.
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
- `policy.use_amp`: runs the policy forward under autocast during eval. For
`model_dtype=bfloat16`, keep this enabled.
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
graph path for faster repeated continuous-action rollout.
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
action generation deterministic per episode for replication.
- `env.task`: comma-separated LIBERO suites or a single suite. Use
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
by the policy processor.
## Performance Results
### LIBERO Benchmark Results
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
compare and test its LeRobot implementation, we fine-tuned
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
results.
The LeRobot fine-tuned checkpoint reported here is available at
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
and was trained on
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
| -------------- | ---------------------: | -----------------: |
| LIBERO Spatial | 98.4% | 97.8% |
| LIBERO Object | 100.0% | 100.0% |
| LIBERO Goal | 98.0% | 97.8% |
| LIBERO 10 | 96.6% | 93.2% |
| Average | 98.25% | 97.20% |
These results demonstrate MolmoAct2's strong performance across diverse robotic
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
evaluation section.
## Differences From the Original Implementation
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
differences from the original training repository are:
- The original paper training stack loads the model in fp32 and trains under
mixed precision. This LeRobot port usually loads the checkpoint directly in
`policy.model_dtype=bfloat16` for lower memory use.
- The original repository uses its own FSDP/model-parallel training path. The
LeRobot port uses the standard LeRobot/Accelerate training path and has not
been tested for multi-node training.
- The original repository supports sequence packing. The LeRobot port trains on
one LeRobot sample per item and pads to an inferred fixed sequence budget.
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
dataset transforms, image augmentation, and Weights & Biases logging
conventions.
- The original training path supports mixed action horizons by padding to
`max_action_horizon` and masking padded horizon slots in the action expert
self-attention. This is useful when training across datasets with different
control frequencies. The LeRobot port currently targets single-dataset
fine-tuning, so `policy.chunk_size` overrides the checkpoint
`max_action_horizon` and horizon masking is not implemented yet. Support for
this mixed-horizon path is planned.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
+1 -1
View File
@@ -91,7 +91,7 @@ lerobot-train \
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
--repo-id=your_dataset \
```
+39
View File
@@ -0,0 +1,39 @@
# MolmoAct2
This repository contains the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
training, evaluation, checkpointing, and dataset compatibility.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## LIBERO Evaluation
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
scene stabilize and can degrade measured success. All LIBERO evaluation results
reported for this LeRobot implementation use `num_steps_wait=50`.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
+39
View File
@@ -0,0 +1,39 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+1 -1
View File
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
If you have existing datasets in v2.1 format, use the migration tool:
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
--repo-id your_id/existing_dataset
```
+185
View File
@@ -0,0 +1,185 @@
# ROBOMETER
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
**Project**: [robometer.github.io](https://robometer.github.io/)
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
## Overview
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
- **Success head**: predicts per-frame task success probability.
- **Preference head**: predicts which of two trajectories better completes the task during training.
The paper trains ROBOMETER with a composite objective:
```text
L = L_pref + L_prog + L_succ
```
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
## What the LeRobot Integration Covers
- Standard `reward_model.type=robometer` configuration through LeRobot.
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
- Dense, frame-level progress and success predictions internally.
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the ROBOMETER dependencies:
```bash
pip install -e ".[robometer]"
```
If you use `uv` directly from a source checkout:
```bash
uv sync --extra robometer
```
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
## Model Inputs and Outputs
ROBOMETER expects:
- A trajectory video or sequence of frames.
- A natural-language task description.
In LeRobot datasets, the preprocessor reads:
| Config field | Default | Meaning |
| ------------------------- | ------------------------ | ----------------------------------------------------- |
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
## Usage
### Load the Reward Model Directly
```python
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
cfg = RobometerConfig(
pretrained_path="lerobot/Robometer-4B",
device="cuda",
reward_output="progress",
)
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
```
### Encode Frames and Compute a Reward
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
```python
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
# task: str
encoder = RobometerEncoderProcessorStep(
base_model_id=cfg.base_model_id,
use_multi_image=cfg.use_multi_image,
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
max_frames=cfg.max_frames,
)
encoded = encoder.encode_samples([(frames, task)])
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
reward = reward_model.compute_reward(batch)
```
`reward` is a tensor of shape `(batch_size,)`.
### Use the Reward Factory
You can also instantiate ROBOMETER through the reward factory:
```python
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
cfg = make_reward_model_config(
"robometer",
pretrained_path="lerobot/Robometer-4B",
device="cuda",
image_key="observation.images.top",
)
reward_model = make_reward_model(cfg)
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
## Configuration Notes
### Backbone and Vocabulary
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
```text
<|split_token|>
<|reward_token|>
<|pref_token|>
<|sim_token|>
<|prog_token|>
```
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
### Progress Prediction
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
### Success Prediction
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
## Limitations
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
## References
- [ROBOMETER project](https://robometer.github.io/)
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
- [Original ROBOMETER code](https://github.com/robometer/robometer)
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
## Citation
```bibtex
@inproceedings{liang2026robometer,
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
year={2026},
booktitle={Robotics: Science and Systems 2026},
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
+235
View File
@@ -0,0 +1,235 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+77
View File
@@ -0,0 +1,77 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
Spawns one single-GPU ``h200`` job that:
1. installs ``lerobot`` from ``main`` plus the annotation extras,
2. boots one vllm server with Qwen3.6-27B (dense VLM),
3. runs the plan / interjections / vqa modules across the dataset
in free-form mode (each episode generates its own subtasks +
memory),
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
or back to ``--repo_id``.
Usage:
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
run. For larger datasets, scale to ``h200x4`` and raise
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
"""
import os
from huggingface_hub import get_token, run_job
token = os.environ.get("HF_TOKEN") or get_token()
if not token:
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
CMD = (
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
"pip install --no-deps "
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
"pip install --upgrade-strategy only-if-needed "
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
"openai && "
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-27B "
"--vlm.num_gpus=1 "
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
"--tensor-parallel-size 1 --max-model-len 32768 "
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
"--vlm.serve_ready_timeout_s=1800 "
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
)
job = run_job(
image="vllm/vllm-openai:latest",
command=["bash", "-c", CMD],
flavor="h200",
secrets={"HF_TOKEN": token},
timeout="2h",
)
print(f"Job URL: {job.url}")
print(f"Job ID: {job.id}")
+42 -11
View File
@@ -115,8 +115,8 @@ dataset = [
]
training = [
"lerobot[dataset]",
"accelerate>=1.10.0,<2.0.0",
"wandb>=0.24.0,<0.25.0",
"wandb>=0.24.0,<0.28.0",
"lerobot[accelerate-dep]",
]
hardware = [
"lerobot[pynput-dep]",
@@ -142,7 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
@@ -177,7 +178,12 @@ unitree_g1 = [
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
reachy2 = [
"reachy2_sdk>=1.0.15,<1.1.0",
"grpcio<=1.73.1",
"protobuf<=6.32.0",
]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
@@ -198,7 +204,8 @@ wallx = [
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
@@ -211,26 +218,43 @@ groot = [
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
# which talks to any OpenAI-compatible server (``vllm serve`` /
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
# (see examples/annotations/run_hf_job.py).
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
"openai>=1.40,<2.0",
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
# uv's single unified lock would then cap ``torch`` for every extra
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
# install it locally only if you run your own ``vllm serve``.
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
@@ -275,10 +299,12 @@ all = [
"lerobot[multi_task_dit]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -289,6 +315,7 @@ all = [
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
@@ -311,6 +338,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
@@ -329,7 +357,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
[tool.setuptools.packages.find]
where = ["src"]
@@ -405,8 +433,11 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"arange",
"is_compileable",
"ROBOTIS",
"OT_VALUE"
"OT_VALUE",
"VanderBilt"
]
# TODO: Uncomment when ready to use
+15
View File
@@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter
__all__ = [
"AnnotationPipelineConfig",
"LanguageColumnsWriter",
"StagingValidator",
"ValidationReport",
]
@@ -0,0 +1,211 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class PlanConfig:
"""``plan`` module: subtasks + plan + memory + task augmentation."""
enabled: bool = True
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
n_task_rephrasings: int = 10
# Derive the task from video instead of episode_task: off / if_short / always.
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
derive_task_from_video: str = "if_short"
derive_task_min_words: int = 3
# --- Frame input: timestamped contact sheets (always on) ---------------
# The subtask describe/segment passes ALWAYS render the episode as
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
# grids with each frame's timestamp burned into its corner, so the VLM
# cites the exact source time of a boundary directly. This is far cheaper
# in vision tokens than one image per frame (≈2× faster subtask generation
# in practice), which is why the sampling is dense by default.
#
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
frames_per_second: float = 2.0
# Frame budget per VLM call (= columns × rows × sheets). When a whole
# episode sampled at ``frames_per_second`` exceeds this, the episode is
# AUTOMATICALLY split into consecutive windows of
# ``max_frames_per_prompt`` frames each (one describe→segment call per
# window, still at the full ``frames_per_second`` density), and the
# per-window spans are merged + stitched into one contiguous cover. So an
# episode of any length is always covered at the full sampling density.
max_frames_per_prompt: int = 60
contact_sheet_columns: int = 5
contact_sheet_frames_per_sheet: int = 20
contact_sheet_frame_width: int = 224
contact_sheet_quality: int = 84
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
# Narrate-only grounding pass before segmenting — best defense against subtasks
# invented from the task text (+1 VLM call/episode).
subtask_describe_first: bool = True
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
emit_plan: bool = True
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
# Symmetric counterpart of ``emit_plan``.
emit_memory: bool = True
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
@dataclass
class TaskAugAxesConfig:
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
enabled: bool = False
synonym_paraphrase: int = 3
omit_arm: int = 3
omit_orientation: int = 2
omit_grasp_method: int = 2
combined_omissions: int = 2
@dataclass
class InterjectionsConfig:
"""``interjections`` module: interjections + paired speech."""
enabled: bool = True
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
max_interjections_per_episode: int = 3
interjection_min_t: float = 2.0
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
interjection_window_seconds: float = 2.0
interjection_window_frames: int = 4
@dataclass
class VqaConfig:
"""``vqa`` module: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
K: int = 1
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
# True: ground VQA only on --vlm.camera_key (default: every camera).
restrict_to_default_camera: bool = False
@dataclass
class VlmConfig:
"""Shared Qwen-VL client configuration."""
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
# auto_serve=True); ``stub`` is for tests.
backend: str = "openai"
model_id: str = "Qwen/Qwen3.6-27B"
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
api_base: str = "http://localhost:8000/v1"
api_key: str = "EMPTY"
# Spawn a server if none answers api_base; False = fail fast on a remote.
auto_serve: bool = True
serve_port: int = 8000
# Override the auto-serve command; ``{port}`` substituted per replica.
serve_command: str | None = None
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
parallel_servers: int = 1
num_gpus: int = 0
client_concurrency: int = 16
serve_ready_timeout_s: float = 600.0
max_new_tokens: int = 512
temperature: float = 0.2
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
max_model_len: int | None = None
# Camera for keyframes; None → first ``observation.images.*`` key.
camera_key: str | None = None
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
chat_template_kwargs: dict[str, Any] | None = None
@dataclass
class ExecutorConfig:
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
# Episodes processed concurrently per phase; main knob for saturating the servers.
episode_parallelism: int = 16
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
# is on and ``new_repo_id`` unset.
repo_id: str | None = None
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
new_repo_id: str | None = None
root: Path | None = None
# Defaults to ``<root>/.annotate_staging/``.
staging_dir: Path | None = None
seed: int = 1729
plan: PlanConfig = field(default_factory=PlanConfig)
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
vqa: VqaConfig = field(default_factory=VqaConfig)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
# library default (torchcodec when available, else PyAV). Or pin
# ``"torchcodec"`` / ``"pyav"`` explicitly.
video_backend: str | None = None
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
push_to_hub: bool = False
push_private: bool = False
push_commit_message: str | None = None
def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -0,0 +1,253 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""In-process executor that runs the annotation phases.
The executor runs **six phases** in dependency order:
phase 1: ``plan`` module (plan + subtasks + memory)
phase 2: ``interjections`` module (interjections + speech)
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: ``vqa`` module (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why the ``plan`` module must be re-entered after the
``interjections`` module — to refresh ``plan`` rows at interjection
timestamps.
Distributed execution is provided by Hugging Face Jobs (see
``examples/annotations/run_hf_job.py``); the runner inside the job
invokes ``lerobot-annotate`` which uses this in-process executor.
Episode-level concurrency is controlled by
``ExecutorConfig.episode_parallelism``.
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .config import AnnotationPipelineConfig
from .reader import EpisodeRecord, iter_episodes
from .staging import EpisodeStaging
from .validator import StagingValidator
from .writer import LanguageColumnsWriter
logger = logging.getLogger(__name__)
@dataclass
class PhaseResult:
"""Summary of one pipeline phase across all episodes."""
name: str
episodes_processed: int
episodes_skipped: int
@dataclass
class PipelineRunSummary:
"""Aggregated result returned by :meth:`Executor.run`."""
phases: list[PhaseResult]
written_paths: list[Path]
validation_report: Any # ValidationReport, kept Any to avoid import cycle
@dataclass
class Executor:
"""Run all six phases over a dataset root in-process.
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
(a thread pool); cluster-level concurrency comes from running this
executor inside a Hugging Face Job. Tests construct the executor
directly with stub modules.
"""
config: AnnotationPipelineConfig
plan: Any # PlanSubtasksMemoryModule
interjections: Any # InterjectionsAndSpeechModule
vqa: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
def run(self, root: Path) -> PipelineRunSummary:
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
n = len(records)
if n == 0:
raise ValueError(f"No episodes found under {root}/data/")
print(f"[annotate] {n} episodes total", flush=True)
staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True)
phases: list[PhaseResult] = []
# Phase 1: ``plan`` module (plan + subtasks + memory)
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
# Phase 2: ``interjections`` module (interjections + speech). It
# reads the ``plan`` module's subtask rows from the same staging
# tree to ground the interjection prompt in the correct local subtask.
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: ``vqa`` module (VQA)
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}")
print(f"[annotate] validator: {report.summary()}", flush=True)
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Keep meta/info.json aligned with the parquet schema we just wrote.
# Idempotent and additive: existing user metadata is preserved.
self._ensure_annotation_metadata_in_info(root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
@staticmethod
def _ensure_annotation_metadata_in_info(root: Path) -> None:
"""Write language features and canonical tools to ``meta/info.json``.
``LanguageColumnsWriter`` adds ``language_persistent`` and
``language_events`` to parquet shards. The metadata must advertise
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
cast against the old schema and fail on the extra parquet columns.
"""
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
info_path = root / "meta" / "info.json"
if not info_path.exists():
return
try:
info = load_info(root)
except Exception as exc: # noqa: BLE001
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
return
changed = False
merged_features = {**info.features, **language_feature_info()}
if merged_features != info.features:
info.features = merged_features
changed = True
existing = info.tools or []
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
info.tools = [*existing, SAY_TOOL_SCHEMA]
changed = True
if changed:
write_info(info, root)
print(
"[annotate] meta/info.json: "
f"language_features={list(language_feature_info())}, "
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
flush=True,
)
def _run_module_phase(
self,
name: str,
records: list[EpisodeRecord],
staging_dir: Path,
module: Any,
) -> PhaseResult:
if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
n = len(records)
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
print(
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
flush=True,
)
t0 = time.time()
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = time.time()
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
return i, record.episode_index, time.time() - ep_start
processed = 0
if parallelism == 1:
for i, record in enumerate(records, 1):
_, ep_idx, elapsed = _do((i, record))
processed += 1
print(
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
flush=True,
)
else:
with ThreadPoolExecutor(max_workers=parallelism) as pool:
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
for fut in as_completed(futures):
i, ep_idx, elapsed = fut.result()
processed += 1
print(
f"[annotate] {name} episode {processed}/{n} "
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase( # noqa: PLR0915
self, records: list[EpisodeRecord], staging_dir: Path
) -> PhaseResult:
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
The ``plan`` module owns the prompt; the ``interjections`` module
produced the timestamps. This phase therefore calls back into the
``plan`` module with the interjection timestamps so its existing
prompt path is reused.
"""
if not self.plan.enabled or not self.interjections.enabled:
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_rows = [
row for row in staging.read("interjections") if row.get("style") == "interjection"
]
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
if interjection_times:
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
processed += 1
# Episodes without any interjections are skipped (no plan refresh
# needed); count them so the summary's processed+skipped == total.
return PhaseResult(
name="plan_update",
episodes_processed=processed,
episodes_skipped=len(records) - processed,
)
@@ -0,0 +1,481 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keyframe extraction for the annotation pipeline.
Modules attach decoded camera frames to their VLM prompts so the model can
ground subtask decomposition, interjection scenarios, and VQA in actual
visual content. The pipeline shares one provider across modules and one
episode at a time, with a small per-episode cache so multiple modules
querying the same timestamp pay decode cost once.
"""
from __future__ import annotations
import io
import logging
import math
import threading
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
import PIL.Image
import torch
from lerobot.configs.video import VideoEncoderConfig
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
from .reader import EpisodeRecord, snap_to_frame
logger = logging.getLogger(__name__)
class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps."""
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` feature keys this provider can decode."""
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
:func:`to_image_blocks` converts them to PIL only at the VLM-message
boundary.
Empty list if the camera is unavailable. ``camera_key=None`` falls back
to the provider's default camera so existing single-camera callers
(the ``plan`` and ``interjections`` modules) keep working unchanged.
"""
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` decoded frames covering the whole episode.
Sampling is uniform across the episode duration. Frames are
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
them into one ``{"type":"video", "video":<list>}`` block for a
Qwen-VL-compatible model that pools temporally itself. Empty list if
no camera available.
"""
@dataclass
class _NullProvider:
"""No-op provider used when the dataset has no video keys or in tests."""
@property
def camera_keys(self) -> list[str]:
return []
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
return []
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
return []
def null_provider() -> FrameProvider:
return _NullProvider()
@dataclass
class VideoFrameProvider:
"""Decodes frames from the dataset's ``observation.images.*`` streams.
By default the *first* camera key is used for the ``plan`` module
(subtask decomposition) and the ``interjections`` module (interjection
scenarios) — those prompts care about *what is happening*, not which
angle. The ``vqa`` module instead iterates over every camera in
:attr:`camera_keys` so each frame's
grounded answer (bbox/keypoint/...) is tagged with the camera it was
grounded against.
``camera_key`` overrides the default-camera choice but does not restrict
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
``video_for_episode`` to read a non-default stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
"""
root: Path
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
# Keyframe decode backend forwarded to
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
# uses the library default (torchcodec when available, else PyAV).
video_backend: str | None = None
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
# one-shot warn flag against concurrent updates from worker threads.
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
# Serializes decode_video_frames calls: torchcodec hands out one
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
# is not safe to drive from multiple threads at once.
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
def __post_init__(self) -> None:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
# Only ``video_keys`` are decodable here: the clip/decode paths read
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
# only for video-stored cameras. Image-stored cameras (also in
# ``camera_keys``) would KeyError, so restrict the list — and the
# default — to video keys.
keys = list(self._meta.video_keys)
# Last-resort fallback: if metadata didn't surface any video keys but
# the caller explicitly named a camera (``--vlm.camera_key=...``),
# trust them — the key is by definition known to exist on the dataset.
if not keys and self.camera_key:
keys = [self.camera_key]
self._camera_keys = keys
if self.camera_key is None:
self.camera_key = keys[0] if keys else None
@property
def camera_keys(self) -> list[str]:
"""All ``observation.images.*`` keys available on this dataset."""
return list(self._camera_keys)
def frames_at(
self,
record: EpisodeRecord,
timestamps: list[float],
camera_key: str | None = None,
) -> list[Any]:
target = camera_key if camera_key is not None else self.camera_key
if not timestamps or target is None:
return []
# Snap each request to the nearest real frame timestamp: callers
# sample uniform grids whose points land mid-frame, and
# ``decode_video_frames`` rejects queries farther than
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
# repeat queries through the cache.
if record.frame_timestamps:
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
out: list[Any] = []
misses: list[float] = []
miss_indices: list[int] = []
with self._lock:
for i, ts in enumerate(timestamps):
key = (record.episode_index, target, round(float(ts), 6))
cached = self._cache.get(key)
if cached is not None:
out.append(cached)
else:
out.append(None)
misses.append(float(ts))
miss_indices.append(i)
if misses:
decoded = self._decode(record.episode_index, misses, target)
# ``_decode`` returns exactly one frame per requested timestamp,
# or an empty list if decoding failed wholesale. A partial list
# would mean a frame/timestamp misalignment, so only pair them up
# when the counts match (``strict=True`` then guards regressions).
if len(decoded) == len(miss_indices):
with self._lock:
for i, frame in zip(miss_indices, decoded, strict=True):
out[i] = frame
key = (record.episode_index, target, round(float(timestamps[i]), 6))
if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = frame
# filter out any None left over from decode failures
return [frame for frame in out if frame is not None]
def video_for_episode(
self,
record: EpisodeRecord,
max_frames: int,
camera_key: str | None = None,
) -> list[Any]:
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
The whole episode duration is covered; the model picks subtask
boundaries from the temporal pooling it does internally. Frames are
``torch.Tensor`` (see :meth:`frames_at`).
"""
target = camera_key if camera_key is not None else self.camera_key
if max_frames <= 0 or target is None or not record.frame_timestamps:
return []
n_frames = min(max_frames, len(record.frame_timestamps))
if n_frames == len(record.frame_timestamps):
timestamps = list(record.frame_timestamps)
else:
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
if t_last <= t0:
timestamps = [float(t0)] * n_frames
else:
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
timestamps = [float(t0 + i * step) for i in range(n_frames)]
return self.frames_at(record, timestamps, camera_key=target)
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
Returns ``None`` if the dataset has no video tracks or extraction
failed. Skips re-extract when the cached clip already exists.
Re-encodes to H.264 via
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
mp4 is decodable by every downstream video processor — stream-copy
would inherit the source codec (often AV1 in modern LeRobot
datasets), which vllm's libav build cannot decode.
"""
if self.camera_key is None:
return None
cache_dir.mkdir(parents=True, exist_ok=True)
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
if out_path.exists() and out_path.stat().st_size > 0:
return out_path
ep = self._meta.episodes[record.episode_index]
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
try:
reencode_video(
src,
out_path,
camera_encoder=encoder,
overwrite=True,
start_time_s=from_timestamp,
end_time_s=to_timestamp,
)
except Exception:
logger.warning(
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
)
return None
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
one explicitly). Returns one frame per requested timestamp, or ``[]``
if decoding failed — callers treat ``[]`` as "no frames available".
"""
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
try:
# The module phases decode under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
# per-file decoder is single-threaded, so serialize decodes on a
# dedicated lock. Frame extraction is a small fraction of episode
# wall time (VLM calls dominate), so the contention is cheap.
with self._decode_lock:
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
)
return list(decoded)
except Exception as exc:
# Log loudly the first time so a silent vqa-module no-op (every
# prompt skipped because frames_at returned []) is debuggable from
# the job log instead of post-hoc parquet inspection. Subsequent
# failures stay quiet.
with self._lock:
already_warned = self._warned_decode_fail
if not already_warned:
self._warned_decode_fail = True
if not already_warned:
logger.warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
episode_index,
camera_key,
video_path,
self.video_backend,
exc,
exc_info=exc,
)
return []
def make_frame_provider(
root: Path, camera_key: str | None = None, video_backend: str | None = None
) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
try:
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
except Exception:
return null_provider()
if provider.camera_key is None:
return null_provider()
return provider
def _frame_to_pil(frame: Any) -> Any:
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
straight from :func:`decode_video_frames`); PIL is only created here, at
the VLM-message boundary, because the chat backends expect PIL images /
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
"""
if not isinstance(frame, torch.Tensor):
return frame
array = frame.detach().cpu()
if array.ndim == 3 and array.shape[0] in (1, 3):
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
if array.shape[-1] == 1:
array = array.squeeze(-1)
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of decoded frames as one Qwen-VL video block.
Returns ``[]`` when the list is empty, so the caller can splat the result
into a content array without a separate emptiness check.
"""
if not frames:
return []
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
"""Wrap a video file URL as one ``video_url`` block.
Used by the ``openai`` backend (transformers serve / vllm serve /
ktransformers serve), where the server handles frame sampling.
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
"""
if not url:
return []
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
A solid black badge with white text, so a VLM reading a contact sheet can
cite the exact source time of each tile (e.g. ``012.50s``) directly,
instead of the caller having to map tile position back to time. Mirrors
the macrodata/refiner contact-sheet convention.
"""
from PIL import ImageDraw, ImageFont
result = image.copy()
draw = ImageDraw.Draw(result)
font = ImageFont.load_default()
label = f"{timestamp:06.2f}s"
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
text_w, text_h = right - left, bottom - top
pad = max(3, round(min(image.width, image.height) * 0.018))
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
return result
def to_contact_sheet_blocks(
frames: Sequence[Any],
timestamps: Sequence[float],
*,
columns: int = 5,
frames_per_sheet: int = 20,
frame_width: int = 224,
quality: int = 84,
) -> list[dict[str, Any]]:
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
Each frame is resized to ``frame_width`` wide, stamped with its
episode-relative timestamp, and tiled row-major into grids of
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
block is returned per grid; many frames collapse into a few images, so a
long episode's temporal coverage stays dense at a fraction of the vision
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
aligned and equal length. Returns ``[]`` for empty input.
"""
from PIL import Image
if not frames:
return []
columns = max(1, columns)
frames_per_sheet = max(1, frames_per_sheet)
rows_per_sheet = math.ceil(frames_per_sheet / columns)
tiles: list[PIL.Image.Image] = []
for ts, frame in zip(timestamps, frames, strict=False):
img = _frame_to_pil(frame)
if not isinstance(img, PIL.Image.Image):
continue
img = img.convert("RGB")
if img.width != frame_width:
height = max(1, round(img.height * frame_width / img.width))
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
tiles.append(_draw_timestamp_badge(img, float(ts)))
if not tiles:
return []
blocks: list[dict[str, Any]] = []
for start in range(0, len(tiles), frames_per_sheet):
chunk = tiles[start : start + frames_per_sheet]
cell_w = max(tile.width for tile in chunk)
cell_h = max(tile.height for tile in chunk)
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
for i, tile in enumerate(chunk):
x = (i % columns) * cell_w
y = (i // columns) * cell_h
sheet.paste(tile, (x, y))
# JPEG round-trip at ``quality`` to match the refiner convention and
# shrink the wire payload; vision-token count is set by resolution, so
# the real saving is the grid packing, not the codec.
buf = io.BytesIO()
sheet.save(buf, format="JPEG", quality=quality)
buf.seek(0)
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
return blocks
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
]
@@ -0,0 +1,248 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``vqa`` module: general VQA at a timed cadence.
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
consecutive frames, and every anchored frame gets its own VQA pair. Each
pair is grounded on that single anchor frame — there is no per-pair frame
window. For datasets with multiple cameras, every anchored frame produces
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
generated against that camera's frame and stamped with the matching
``camera`` field on the emitted rows. The resolver disambiguates via
``camera=...``; recipes that consume VQA do so through one sub-recipe
per camera (see ``recipes/pi05_hirobot.yaml``).
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
"""
from __future__ import annotations
import json
import logging
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import VqaConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..validator import classify_vqa_answer
from ..vlm_client import VlmClient
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
"""Return the relative frame indices to anchor VQA emissions to.
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
consecutive frames starting at the tick. Ticks fall on the nearest
available source frame timestamp.
"""
if hz <= 0 or k <= 0 or not frame_timestamps:
return []
t0 = frame_timestamps[0]
t_last = frame_timestamps[-1]
period = 1.0 / hz
indices: list[int] = []
t = t0
while t <= t_last + 1e-9:
# find the index of the nearest frame to t
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
for offset in range(k):
j = nearest_i + offset
if j >= len(frame_timestamps):
break
if not indices or indices[-1] != j:
indices.append(j)
t += period
# dedupe while preserving order
seen: set[int] = set()
deduped: list[int] = []
for i in indices:
if i in seen:
continue
seen.add(i)
deduped.append(i)
return deduped
@dataclass
class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: VqaConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
_warned_no_camera: bool = field(default=False, init=False, repr=False)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("vqa", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
cameras = self._target_cameras()
if not cameras:
# No camera available — emit nothing rather than producing
# untagged rows that would fail validation. Surface a loud one-
# time warning so this is never silently a no-op.
if not self._warned_no_camera:
logging.getLogger(__name__).warning(
"vqa module found no cameras on the frame provider — "
"every episode will emit zero VQA rows. Check that the "
"dataset declares observation.images.* features in "
"meta/info.json; passing --vlm.camera_key=<key> at the "
"CLI now also seeds the cameras list as a fallback."
)
self._warned_no_camera = True
staging.write("vqa", [])
return
# Build all messages first (one per (frame, camera)), then issue them
# as a single batched generate_json call so the client can fan them
# out concurrently.
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
for camera in cameras:
messages = self._build_messages(record, qtype, ts, camera)
# Skip cameras that decoded to zero frames at this ts: no point
# asking the VLM to ground a bbox without an image.
if not _has_image_block(messages):
continue
per_call.append((ts, camera, qtype, messages))
if not per_call:
staging.write("vqa", [])
return
results = self.vlm.generate_json([m for _, _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
qa = self._postprocess(result)
if qa is None:
continue
question, answer = qa
rows.append(
{
"role": "user",
"content": question,
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
rows.append(
{
"role": "assistant",
"content": json.dumps(answer, sort_keys=True),
"style": "vqa",
"timestamp": ts,
"camera": camera,
"tool_calls": None,
}
)
staging.write("vqa", rows)
def _target_cameras(self) -> list[str]:
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
Defaults to every camera the provider exposes. Datasets with no
cameras (or test/null providers) yield an empty list, which makes
``run_episode`` a no-op.
When ``config.restrict_to_default_camera`` is set, VQA grounds on
only the provider's default camera (the single ``--vlm.camera_key``
stream), matching the plan / interjection modules so the whole
pipeline focuses on one view.
"""
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
if getattr(self.config, "restrict_to_default_camera", False):
default = getattr(self.frame_provider, "camera_key", None)
if default and default in all_cameras:
return [default]
# ``restrict_to_default_camera`` is set but the configured default
# isn't one the provider exposes. Returning it anyway would make
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
# fall through to every available camera instead.
if default:
logging.getLogger(__name__).warning(
"restrict_to_default_camera is set but camera_key=%r is not in the "
"provider's cameras %s; grounding VQA on all available cameras instead.",
default,
all_cameras,
)
return all_cameras
def _build_messages(
self,
record: EpisodeRecord,
question_type: str,
frame_timestamp: float,
camera_key: str,
) -> list[dict[str, Any]]:
prompt = load_prompt("vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict):
return None
question = result.get("question")
answer = result.get("answer")
if not isinstance(question, str) or not question.strip():
return None
if not isinstance(answer, dict):
return None
# The validator will enforce shape; here we just sanity-check that the
# answer matches *some* known shape so we can drop garbage early.
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
"""Return True if any user content block is a populated image block."""
for msg in messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for block in content:
if isinstance(block, dict) and block.get("type") == "image":
return True
return False
@@ -0,0 +1,211 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
canonical task). No interjection row — the canonical task is already the
user utterance from ``meta/tasks.parquet``.
2. For mid-episode interruptions, emit a co-timestamped pair:
{role:user, style:interjection, content:<text>}
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
interjection timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import InterjectionsConfig
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom
@dataclass
class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: InterjectionsConfig
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
if record.frame_timestamps:
t0 = float(record.frame_timestamps[0])
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
# Pull the ``plan`` module's subtask spans for this episode so the
# interjection prompt can ground itself in the actual current
# subtask at each chosen timestamp. The ``plan`` module ran first.
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
rows.extend(self._mid_episode_interjections(record, subtask_spans))
staging.write("interjections", rows)
@staticmethod
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
current: str | None = None
for span in spans:
if float(span["start"]) <= t:
current = span.get("text")
else:
break
return current
def _initial_speech(self, record: EpisodeRecord) -> str | None:
prompt = load_prompt("interjections_initial_speech").format(
episode_task=record.episode_task,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("text"), str):
text = result["text"].strip()
if text:
return text
return None
def _mid_episode_interjections(
self,
record: EpisodeRecord,
subtask_spans: Sequence[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Generate interjections aligned with the actual demo trajectory.
Teleop data is frozen — the robot already executed every step in
the video. A *counterfactual* interjection like "actually skip
the wipe" contradicts what then happens in the video, which is
what qwen36moe-10/11 surfaced as low-quality interjections.
Instead, anchor every interjection at a subtask boundary and
write it as a natural user request for the *upcoming* subtask.
The robot's visible next behavior IS the interjection's effect,
so the training signal stays consistent: interjection text →
plan refresh → action stream all line up.
"""
if self.config.max_interjections_per_episode <= 0:
return []
if len(subtask_spans) < 2:
# Need at least one transition (subtask 0 → subtask 1).
return []
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
# Boundaries: the start time of every subtask except the first
# (which is just t0 and is covered by the initial-task speech atom).
boundaries: list[tuple[float, str, str]] = []
for i in range(1, len(subtask_spans)):
ts = float(subtask_spans[i]["start"])
if ts < self.config.interjection_min_t:
continue
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
next_text = (subtask_spans[i].get("text") or "").strip()
if not next_text:
continue
boundaries.append((ts, prev_text, next_text))
if not boundaries:
return []
n = min(self.config.max_interjections_per_episode, len(boundaries))
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
out: list[dict[str, Any]] = []
for t, prev_subtask, next_subtask in chosen:
t_snap = snap_to_frame(t, record.frame_timestamps)
# Window straddles the boundary so the VLM sees the end of the
# previous subtask and the start of the next one — same
# conditioning the policy will see at training time.
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
prompt = load_prompt("interjections_interjection").format(
episode_task=record.episode_task,
prev_subtask=prev_subtask or "(starting from initial state)",
next_subtask=next_subtask,
timestamp=t_snap,
window_seconds=self.config.interjection_window_seconds,
)
images = self.frame_provider.frames_at(record, window_ts)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
interjection_text = result.get("interjection")
speech_text = result.get("speech")
if not isinstance(interjection_text, str) or not interjection_text.strip():
continue
if not isinstance(speech_text, str) or not speech_text.strip():
continue
out.append(
{
"role": "user",
"content": interjection_text.strip(),
"style": "interjection",
"timestamp": t_snap,
"tool_calls": None,
}
)
out.append(speech_atom(t_snap, speech_text.strip()))
return out
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
"""Return a small set of frame timestamps centered on ``t_anchor``.
The window straddles the subtask boundary the interjection sits
on: roughly half the frames cover the end of the previous
subtask, half cover the start of the next one. The VLM therefore
sees BOTH what just finished AND what's about to start, which is
the conditioning we need to write a natural "now please do X"
request that matches the visible upcoming behavior.
"""
if not frame_timestamps:
return [t_anchor]
n = max(1, int(self.config.interjection_window_frames))
if n == 1:
return [t_anchor]
window = float(self.config.interjection_window_seconds)
step = window / max(1, n - 1)
# Center the window on the anchor so half lands before, half after.
start_offset = -window / 2.0
targets = [t_anchor + start_offset + step * i for i in range(n)]
first_ts = float(frame_timestamps[0])
last_ts = float(frame_timestamps[-1])
snapped: list[float] = []
seen: set[float] = set()
for tgt in targets:
clamped = min(last_ts, max(first_ts, tgt))
t = snap_to_frame(clamped, frame_timestamps)
if t not in seen:
seen.add(t)
snapped.append(t)
return snapped or [t_anchor]
@@ -0,0 +1,780 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import PlanConfig
from ..frames import (
FrameProvider,
null_provider,
to_contact_sheet_blocks,
)
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
logger = logging.getLogger(__name__)
# Prepended to every describe / segment prompt so the VLM knows the images are
# timestamped contact-sheet grids, not a single video, and reads the burned-in
# per-tile timestamp when choosing boundaries.
def _contact_sheet_preamble(columns: int) -> str:
return (
"CONTACT SHEETS — how to read the images below:\n"
f"- Each image is a grid of sampled video frames, {columns} per row, "
"with time running left-to-right then top-to-bottom (row-major).\n"
"- Each frame has its timestamp burned into the top-left corner, e.g. "
'"012.50s". Use that printed timestamp (not the tile position) when you '
"choose start/end times; boundaries should land on or near a printed "
"timestamp.\n"
"- Frames continue across grids: an action may span the end of one sheet "
"and the start of the next, so do not place a boundary just because a new "
"image begins.\n\n"
)
# Appended to every describe (and segment) prompt. A visual, causal definition
# of where one event ends and the next begins — adapted from macrodata/refiner —
# to sharpen cut points while the existing prompt keeps owning the imperative
# phrasing.
_CAUSAL_BOUNDARY_RULES = (
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
"- Start a new event whenever the world state changes: an object becomes "
"held (the gripper closes on it), an object is released (the gripper opens "
"and it stays put), an object reaches a new location, a lid/door/drawer "
"changes open/closed state, a tool starts or stops affecting a surface, or "
"contents visibly move (e.g. poured).\n"
"- If a single action changes the same state gradually and continuously, "
"keep it as ONE event — do not split it.\n"
"- If the same action repeats on different objects or target locations, "
"treat each repetition as a separate event.\n"
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
"tiny hand adjustments."
)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
All output is persistent (lives in ``language_persistent``):
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
the ``interjections`` module completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: PlanConfig
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
# Task driving every plan-module prompt: canonical episode_task, or a
# video-derived one when it's empty/placeholder (see derive_task_*).
effective_task = self._resolve_effective_task(record)
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
# free-form n_task_rephrasings; the effective task is always emitted
# first so the rotation covers the source-of-truth phrasing.
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
variants: list[str] | None = None
if self.config.task_aug_axes.enabled and effective_task:
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
elif self.config.n_task_rephrasings > 0 and effective_task:
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
if variants is not None:
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
subtask_spans = self._generate_subtasks(record, task=effective_task)
# subtask rows
for span in subtask_spans:
rows.append(
{
"role": "assistant",
"content": span["text"],
"style": "subtask",
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
"tool_calls": None,
}
)
# Plan rows at every subtask boundary (incl. t=0). The plan is a
# numbered list of still-todo subtasks, so re-emitting at each
# boundary makes it shrink as work progresses — ${plan} at frame t is
# exactly what's left to do.
if self.config.emit_plan:
for span in subtask_spans:
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
plan_text = self._generate_plan(
record, subtask_spans, refresh_t=boundary_t, task=effective_task
)
if plan_text is not None:
rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": float(boundary_t),
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start;
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
prior_memory = ""
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
for i, span in memory_boundaries:
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
if mem_text:
ts = snap_to_frame(span["start"], record.frame_timestamps)
rows.append(
{
"role": "assistant",
"content": mem_text,
"style": "memory",
"timestamp": ts,
"tool_calls": None,
}
)
prior_memory = mem_text
staging.write("plan", rows)
# ------------------------------------------------------------------
# Task derivation + rephrasings
# ------------------------------------------------------------------
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
{
"debug",
"test",
"tbd",
"todo",
"n/a",
"na",
"untitled",
"unnamed",
"default",
"placeholder",
}
)
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
"""Decide which task string drives the ``plan`` module for this episode.
Returns the user-supplied ``record.episode_task`` unless
``derive_task_from_video`` says otherwise (see config docstring).
Falls back gracefully to the canonical task if video derivation
fails.
"""
canonical = (record.episode_task or "").strip()
mode = (self.config.derive_task_from_video or "off").strip().lower()
if mode == "always":
derived = self._derive_task_from_video(record)
return derived or canonical
if mode == "if_short" and self._task_seems_bad(canonical):
derived = self._derive_task_from_video(record)
if derived:
return derived
return canonical
def _task_seems_bad(self, task: str) -> bool:
if not task:
return True
if len(task.split()) < int(self.config.derive_task_min_words):
return True
return task.lower() in self._PLACEHOLDER_TASKS
@staticmethod
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
seen: set[str] = set()
rows: list[dict[str, Any]] = []
for phrasing in phrasings:
key = phrasing.strip()
if not key or key in seen:
continue
seen.add(key)
rows.append(
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
)
return rows
# ------------------------------------------------------------------
# VLM call helpers — every plan-module prompt follows the same shape:
# build messages → single VLM call → pull a named field.
# ------------------------------------------------------------------
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
"""Run a single VLM call and return ``result[field]`` or ``None``.
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
dance every prompt-call site needs.
"""
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict):
return result.get(field)
return None
@staticmethod
def _text_message(text: str) -> list[dict[str, Any]]:
"""One-shot text-only user message wrapped for ``generate_json``."""
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
def _video_message(
self,
record: EpisodeRecord,
prompt: str,
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
The prompt is always prefixed with a short explanation of how to read
the timestamped grids, so the model treats them as one ordered
sequence of frames rather than unrelated images.
"""
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
"""Ask the VLM "what is this video about" with no task hint at all."""
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
return text.strip() if isinstance(text, str) and text.strip() else None
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
"""Generate ``n`` text-only paraphrases of ``base_task``."""
if n <= 0 or not base_task:
return []
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
if not isinstance(raw, list):
return []
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
return [s for s in out if s][:n]
# ------------------------------------------------------------------
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
# ------------------------------------------------------------------
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
"""One VLM call → variants along the 5-axis taxonomy.
Variants from all axes are flattened into a single list (the
downstream pipeline doesn't need to know about the per-axis
bucketing — every variant becomes a ``task_aug`` row). Order
is preserved for reproducibility: synonym_paraphrase first,
then omit_arm, then omit_orientation, then omit_grasp_method,
then combined_omissions.
"""
if not base_task:
return []
prompt = load_prompt("plan_task_aug_axes").format(
base_task=base_task,
n_synonym=axes_cfg.synonym_paraphrase,
n_omit_arm=axes_cfg.omit_arm,
n_omit_orientation=axes_cfg.omit_orientation,
n_omit_grasp_method=axes_cfg.omit_grasp_method,
n_combined=axes_cfg.combined_omissions,
)
result = self.vlm.generate_json([self._text_message(prompt)])[0]
if not isinstance(result, dict):
return []
ordered_axes = (
"synonym_paraphrase",
"omit_arm",
"omit_orientation",
"omit_grasp_method",
"combined_omissions",
)
flat: list[str] = []
seen: set[str] = set()
for axis in ordered_axes:
entries = result.get(axis)
if not isinstance(entries, list):
continue
for item in entries:
if not isinstance(item, str):
continue
key = item.strip().strip('"').strip("'")
if not key or key in seen:
continue
seen.add(key)
flat.append(key)
return flat
def _episode_video_block(
self, record: EpisodeRecord, window: tuple[float, float] | None = None
) -> list[dict[str, Any]]:
"""Timestamped contact sheets for the describe / segmentation prompts.
Always renders the (optionally windowed) episode as contact sheets:
frames sampled at ``frames_per_second`` and packed into timestamped
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
episodes that exceed it are windowed upstream in
:meth:`_generate_subtasks` so each call stays within budget while the
full episode keeps its sampling density.
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
(``ts - w0``) to match the window-relative time frame the
segmentation prompt works in (spans are offset back to absolute time
afterwards).
"""
if not record.frame_timestamps:
return []
if window is not None:
w0, w1 = float(window[0]), float(window[1])
dur = max(0.0, w1 - w0)
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
if n <= 1 or dur <= 0.0:
timestamps = [0.5 * (w0 + w1)]
else:
step = dur / (n - 1)
timestamps = [w0 + i * step for i in range(n)]
frames = self.frame_provider.frames_at(record, timestamps)
rel = [ts - w0 for ts in timestamps[: len(frames)]]
return self._contact_sheet_blocks(frames, rel)
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
timestamps = self._uniform_episode_timestamps(record, n)
frames = self.frame_provider.frames_at(record, timestamps)
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
@staticmethod
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
ts = record.frame_timestamps
if n >= len(ts):
return [float(t) for t in ts]
t0, t_last = float(ts[0]), float(ts[-1])
if t_last <= t0 or n <= 1:
return [t0] * max(1, n)
step = (t_last - t0) / (n - 1)
return [t0 + i * step for i in range(n)]
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
"""Build timestamped contact-sheet image blocks from decoded frames."""
return to_contact_sheet_blocks(
frames,
timestamps,
columns=self.config.contact_sheet_columns,
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
frame_width=self.config.contact_sheet_frame_width,
quality=self.config.contact_sheet_quality,
)
def run_plan_updates(
self,
record: EpisodeRecord,
staging: EpisodeStaging,
interjection_times: Sequence[float],
interjection_texts: Sequence[str] | None = None,
) -> None:
"""Append additional ``plan`` rows at every interjection timestamp.
Plans refresh ONLY on user interjections (event-driven). The
interjection text is forwarded into the prompt so the refreshed plan
reflects the user's correction.
"""
if not self.config.emit_plan:
return
existing = staging.read("plan")
# Pass the last frame timestamp so the final span is closed (else its
# end == start, zero duration, and a refresh inside it is missed).
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
new_rows = list(existing)
texts: list[str | None] = (
[None] * len(interjection_times)
if interjection_texts is None
else [str(t) if t else None for t in interjection_texts]
)
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
t = snap_to_frame(raw_t, record.frame_timestamps)
if t in already_planned:
continue
already_planned.add(t)
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
if plan_text is not None:
new_rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": t,
"tool_calls": None,
}
)
staging.write("plan", new_rows)
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
"""Generate subtask spans, optionally via a multi-call quality chain.
Single call (default): watch video → emit subtask JSON.
Multi-call (opt-in, higher quality, more VLM calls):
1. ``subtask_describe_first`` — a grounding pass that narrates
ONLY what is visible (no JSON commitment to subtasks yet);
its description is injected into the segmentation prompt so
the model segments its own grounded observations instead of
pattern-matching the task text.
2. segmentation — emit subtask JSON (as before).
"""
if record.row_count == 0 or not record.frame_timestamps:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
effective_task = task if task is not None else record.episode_task
# ---- Auto-windowing (keeps the full sampling density) --------
# Contact sheets are cheap, but a whole long episode sampled at
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
# When it does, split into consecutive windows of exactly that many
# frames (one describe→segment call each, still at the full sampling
# density), then merge + stitch — so an episode of any length is
# covered at full density rather than subsampled into one sparse call.
fps = max(1e-6, float(self.config.frames_per_second))
n_whole = int(round(episode_duration * fps)) + 1
if n_whole > self.config.max_frames_per_prompt:
window_s = self.config.max_frames_per_prompt / fps
return self._generate_subtasks_windowed(record, effective_task, window_s)
# ---- Pass 1 (optional): grounding description ----------------
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, effective_task)
if description:
observation_block = (
"You watched this video and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the video) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
# ---- Pass 2: segmentation ------------------------------------
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=effective_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
observation_block=observation_block,
)
)
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
cleaned = self._clean_spans(spans, record)
if not cleaned:
return []
# ---- Full-episode coverage stitch ----------------------------
# The VLM can start after t0 or leave gaps, so frames fall through
# with no active subtask. Always stitch into a contiguous
# [t0, t_last] cover.
cleaned = self._stitch_full_coverage(cleaned, record)
return cleaned
def _generate_subtasks_windowed(
self, record: EpisodeRecord, task: str, window_s: float
) -> list[dict[str, Any]]:
"""Subtask generation in fixed-length windows at constant fps.
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
seconds, runs the describe -> segment chain on each window's own
frames (sampled at ``frames_per_second``), offsets
each window's spans back to absolute episode time, then merges +
stitches into a contiguous whole-episode cover.
"""
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
all_spans: list[dict[str, Any]] = []
w0 = t0
n_windows = 0
while w0 < t_last - 1e-6:
w1 = min(w0 + window_s, t_last)
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
n_windows += 1
w0 = w1
logger.info(
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
record.episode_index,
n_windows,
window_s,
len(all_spans),
)
# Merge across windows: clamp to the absolute episode, sort, and
# frame-snap to distinct starts (handles any boundary collisions).
cleaned = self._clean_spans(all_spans, record)
if not cleaned:
return []
return self._stitch_full_coverage(cleaned, record)
def _subtasks_for_window(
self, record: EpisodeRecord, task: str, w0: float, w1: float
) -> list[dict[str, Any]]:
"""Run describe -> segment on one ``[w0, w1]`` window.
The model works in window-RELATIVE time ``[0, L]`` (it perceives
the window as a clip starting at 0); spans are offset back to
absolute ``[w0, w1]`` before returning.
"""
window = (w0, w1)
win_len = max(0.0, w1 - w0)
observation_block = ""
if getattr(self.config, "subtask_describe_first", False):
description = self._describe_episode(record, task, window=window)
if description:
observation_block = (
"You watched this video clip and described, chronologically, "
"ONLY what the robot actually does:\n"
f'"""{description}"""\n\n'
"Segment THAT grounded description (cross-checked against "
"the clip) into atomic subtasks. Do not introduce any "
"action that is not in your description above.\n\n"
)
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{win_len:.3f}",
observation_block=observation_block,
)
)
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Window-relative clamp; no frame-snap dedupe yet (done on the
# merged absolute set).
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
if not cleaned:
return []
# Offset window-relative spans back to absolute episode time.
for s in cleaned:
s["start"] = w0 + float(s["start"])
s["end"] = w0 + float(s["end"])
return cleaned
def _stitch_full_coverage(
self, spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Make subtask spans tile the full episode with no gaps.
* The first subtask starts at the episode's first frame ``t0``
(any idle / approach before the first labelled action is folded
into it), so every early frame has an active subtask.
* Each subtask's ``end`` is snapped to the next subtask's
``start`` (gaps between spans are closed), and the final
subtask's ``end`` extends to the last frame ``t_last``.
Starts are otherwise left as the (already frame-snapped, distinct)
values the VLM produced — only the FIRST start is pulled
back to ``t0``, which can't collide with a later span because it
was already the earliest. Purely deterministic; runs after the
VLM passes.
"""
if not spans or not record.frame_timestamps:
return spans
t0 = float(record.frame_timestamps[0])
t_last = float(record.frame_timestamps[-1])
spans = sorted(spans, key=lambda s: float(s["start"]))
spans[0]["start"] = t0
for i in range(len(spans) - 1):
spans[i]["end"] = float(spans[i + 1]["start"])
spans[-1]["end"] = t_last
for s in spans:
if float(s["end"]) < float(s["start"]):
s["end"] = float(s["start"])
return spans
@staticmethod
def _with_causal_rules(prompt: str) -> str:
"""Append the causal event-boundary rules to a describe/segment prompt."""
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
def _clean_spans(
self,
spans: Any,
record: EpisodeRecord,
bounds: tuple[float, float] | None = None,
dedupe: bool = True,
) -> list[dict[str, Any]]:
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
``bounds`` overrides the clamp range — pass the window's
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
``None`` to clamp to the whole episode ``[t0, t_last]``.
``dedupe`` runs the frame-snap distinct-start step; skip it for
window-relative spans (frame snapping is done once on the merged,
absolute-time set).
"""
if not spans:
return []
if bounds is not None:
lo, hi = float(bounds[0]), float(bounds[1])
else:
lo = record.frame_timestamps[0]
hi = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
start = float(span["start"])
end = float(span["end"])
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(lo, min(start, hi))
end = max(lo, min(end, hi))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
if dedupe:
return self._dedupe_starts_to_distinct_frames(cleaned, record)
return cleaned
def _describe_episode(
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
) -> str:
"""Grounding pass: free-form chronological description of the (windowed) video."""
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
return text.strip() if isinstance(text, str) and text.strip() else ""
@staticmethod
def _dedupe_starts_to_distinct_frames(
spans: list[dict[str, Any]], record: EpisodeRecord
) -> list[dict[str, Any]]:
"""Bump same-frame subtask starts onto distinct frames.
Two consecutive VLM spans whose ``start`` rounds to the same
source frame (after :func:`snap_to_frame`) would otherwise emit
two ``style=subtask`` rows at the identical persistent
timestamp. The training-time renderer's ``active_at(t,
style=subtask)`` resolver can't disambiguate that and raises
``Ambiguous resolver for style='subtask'``.
Walk the (sorted-by-start) spans, snap each to its frame, and
if the snapped frame is already taken push the span onto the
next unused frame so both subtasks survive on distinct
timestamps. If the episode ends before a free frame is found,
the trailing span is dropped with a warning — better than
poisoning the render.
"""
if not spans:
return spans
frames = record.frame_timestamps
if not frames:
return spans
used: set[float] = set()
out: list[dict[str, Any]] = []
for span in spans:
ts = snap_to_frame(span["start"], frames)
if ts in used:
next_ts = next((f for f in frames if f > ts and f not in used), None)
if next_ts is None:
logger.warning(
"episode %d: subtask %r snapped to occupied frame "
"%.3f and no free later frame exists — dropping",
record.episode_index,
span.get("text"),
ts,
)
continue
ts = next_ts
used.add(ts)
new_span = {**span, "start": ts}
if float(new_span.get("end", ts)) < ts:
new_span["end"] = ts
out.append(new_span)
return out
def _generate_plan(
self,
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
subtask_spans: Sequence[dict[str, Any]],
*,
refresh_t: float | None = None,
interjection: str | None = None, # noqa: ARG002
task: str | None = None, # noqa: ARG002
) -> str | None:
"""Deterministic plan = numbered list of *still-todo* subtasks.
No VLM call: a plain numbered list keeps the plan aligned with the
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
cost a round-trip per episode/refresh and could diverge).
1. <subtask 1>
2. <subtask 2>
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
interjections, and ``run_episode`` at each boundary), only subtasks
starting at or after ``refresh_t`` are included — so it always
describes what's left.
"""
if not subtask_spans:
return None
remaining = [
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
]
if not remaining:
# Past the last subtask boundary on a late refresh — nothing
# left to plan; emit None so the caller skips the row.
return None
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
def _generate_memory(
self,
record: EpisodeRecord,
prior_memory: str,
completed: str,
remaining: Sequence[str],
*,
task: str | None = None,
) -> str:
prompt = load_prompt("plan_memory").format(
episode_task=(task if task is not None else record.episode_task),
prior_memory=prior_memory or "(none)",
completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
)
memory = self._vlm_field(self._text_message(prompt), "memory")
return memory.strip() if isinstance(memory, str) else ""
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.
One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""
from __future__ import annotations
from pathlib import Path
_DIR = Path(__file__).parent
def load(name: str) -> str:
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
path = _DIR / f"{name}.txt"
return path.read_text(encoding="utf-8")
@@ -0,0 +1,12 @@
The user just asked the robot: "{episode_task}".
Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: compact, confident, friendly.
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".
Prefer very short replies: "Got it.", "On it.", "OK."
Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -0,0 +1,46 @@
You are generating training data for a Hi Robot-style hierarchical
robot policy. The robot in this demonstration has ALREADY executed
every step shown in the video — we cannot retroactively change the
action stream. To keep training data consistent with the video, the
"interjection" must align with what the robot is *about to do next* in
the demonstration, framed as a natural mid-task user request.
The episode's overall task: "{episode_task}".
The images above show roughly {window_seconds:.1f} seconds straddling a
subtask boundary in the demonstration:
- Subtask the robot just finished: "{prev_subtask}"
- Subtask the robot is about to start: "{next_subtask}"
- Time into episode: {timestamp:.2f}s
Write ONE compact interjection the user would naturally say at this
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
Keep it like a mid-task coaching cue, not a full instruction paragraph.
Also write the robot's compact verbal acknowledgement.
Hard rules:
- The interjection MUST be consistent with the next subtask. The user
cannot ask for something different from what the robot then does in
the video. If you're tempted to say "actually skip X" or "do Y
instead", DO NOT — those would contradict the demonstration.
- The interjection must reference an object, location, or action that
is plausible given the visible scene and the next subtask text.
- One short phrase or sentence each. Conversational, not robotic.
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
- Keep robot speech very short: "OK.", "On it.", "Doing that."
Style examples (vary the phrasing — don't reuse these verbatim):
- "Now go ahead and {next_subtask}."
- "Great, can you {next_subtask} next?"
- "{next_subtask}, please."
- "Before you continue, please {next_subtask}."
- "Looking good — {next_subtask} now."
- "Okay, {next_subtask}."
Output strictly valid JSON:
{{
"interjection": "<short cue from the user, asking for the next subtask>",
"speech": "<short robot acknowledgement>"
}}
@@ -0,0 +1,36 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.
Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."
Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
robot has accomplished so far — the running story it would tell itself.
Authoring rules:
- First person, past tense. Every sentence starts with "I": "I picked
up...", "I opened...", "I moved to...".
- One or two short sentences. Extend the previous memory with the
just-completed subtask; do not rewrite it from scratch.
- Keep WHAT happened (functional outcomes — where items went, how many),
drop HOW (grasp details, motions).
- Compress completed steps and drop object attributes (colors, exact
counts) once they no longer affect the remaining subtasks.
Example (MEM, Torne 2026):
Before: "I prepared the pot and got the potatoes, milk, and butter. I
moved to the drawer."
After: "I prepared the pot and got the ingredients. I opened the
drawer with the masher."
Output strictly valid JSON:
{{ "memory": "<one or two short first-person past-tense sentences>" }}
@@ -0,0 +1,27 @@
You are watching a teleoperated robot demonstration from a single
camera. The user asked the robot to: "{episode_task}"
This is an OBSERVATION pass. Watch the entire clip and describe, in
chronological order, ONLY what the robot physically does — the concrete
motions, approaches, contacts, grasps, releases, and relocations you can
actually SEE in the frames.
Hard rules:
- Describe only motion visible in the video. Do NOT use the task
instruction to guess steps that aren't shown. The instruction is the
goal; the video is ground truth.
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
the single field below. Just narrate what happens.
- Give an approximate timestamp (in seconds) for each distinct event,
e.g. "0.0-1.4s: the base drives forward toward the stove".
- Do NOT invent objects, grasps, destinations, or steps. If the robot
only does one thing (e.g. it just navigates and the clip ends), say
exactly that and nothing more.
- Be concrete and literal. "the gripper closes on the mug" — not "the
robot prepares to make coffee".
Output strictly valid JSON:
{{
"description": "<chronological, timestamped description of ONLY what is visible>"
}}
@@ -0,0 +1,112 @@
You are labeling a teleoperated robot demonstration.
The user originally asked: "{episode_task}"
You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs.
{observation_block}GROUNDING — read this first, it overrides everything below:
- Label ONLY what the robot actually does in the video. Every subtask
you emit must correspond to motion you can SEE in specific frames.
- Do NOT invent, anticipate, or pad. If the robot only does one thing
(e.g. it just navigates to a location and the clip ends), emit
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
subtasks than the ceiling is not just allowed, it is expected for
short / atomic demonstrations. One correct subtask is far better
than several invented ones.
- If the video does not clearly show the action implied by the task,
describe what you actually see — do NOT fabricate the task's steps
from the instruction text. The instruction tells you the goal; the
VIDEO is the ground truth for what happened.
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
- Each subtask = one COMPOSITE atomic skill the low-level policy can
execute end-to-end. A "skill" bundles its own approach motion with
its terminal action — do NOT split the approach off as its own
subtask. The whole-arm policy already learns to reach as part of
every manipulation primitive.
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
these verbs (extend only when none fits):
pick up <obj> — approach + grasp + lift in one subtask
put <obj> on/in <loc> — transport + release in one subtask
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
push <obj> — contact + linear shove
pull <obj> — contact + linear retract
turn <knob/dial/handle> — rotary actuation
press <button> — single-press contact
open <drawer/door/lid> — full open motion
close <drawer/door/lid> — full close motion
pour <src> into <dst> — tilt + flow
insert <obj> into <slot>— alignment + push-fit
go to <loc> — ONLY when no grasp / actuation follows
(e.g. a pure relocation between phases).
If the next subtask grasps something at
that location, drop "go to ..." and just
write "pick up ..." instead.
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
as standalone subtasks; fold them into the parent composite:
"move to X" → fold into "pick up X" (or whatever follows)
"reach for X" → fold into "pick up X"
"grasp X" → fold into "pick up X"
"lift X" → fold into "pick up X" (or "put X on Y" if it's
the transport phase of a place)
"release X" → fold into "put X on Y" (or "place X in Y")
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
detail (which hand, which grasp point) ONLY when it is needed to
disambiguate. Every subtask must begin with one of the verbs
above (no leading nouns, no "then", no "first").
- NEVER use third person. Never write "the robot", "the arm", "the
gripper moves", "it picks up" — the robot is implied. Command it,
do not describe it.
- Use the exact object nouns from the task above. If the task says
"cube", every subtask says "cube" — never switch to "block". If it
says "box", never switch to "bin"/"container". Keep vocabulary
consistent across the whole episode.
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
"turn red knob", "press start button", "go to sink".
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
must be folded into "pick up blue cube"); "the robot arm moves
towards the blue cube" (third person, too long); "carefully pick
up the cube" (adverb, article); "release the yellow block"
("block" when the task said "cube", and "release" must be folded
into a "put"/"place" subtask).
- Subtasks are non-overlapping and cover the full episode in order.
Choose the cut points yourself based on what you see in the video
(gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds. If a
candidate span would be shorter, merge it into its neighbour
rather than emitting it.
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
are preferred over many micro-steps.
- Every subtask's [start_time, end_time] must lie within
[0.0, {episode_duration}] seconds.
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
fires ONLY on the spatial situation it names; it must not change how you
label any other situation):
- STACK vs PUT: if an object is placed ON TOP OF another specific object
(not on a flat table / shelf / counter), use "stack ... on ...", not
"put". "stack blue book on green book", NOT "put blue book on table".
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
receptacle (push-fit), use "insert ... into ...", not "put".
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
on the object and the object moves WITH the hand, it is "pick up" /
"retrieve" (object leaves its location). If the gripper OPENS and the
object stays where the hand left it, it is "put" / "place" (object
arrives at a location). Decide by which way the object moves, not by
where the hand ends up.
- POUR vs PUT: only use "pour" when the source is tilted and contents
flow out; moving a full container without tilting is "put"/"place".
Output strictly valid JSON of shape:
{{
"subtasks": [
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
...
]
}}
@@ -0,0 +1,67 @@
You are generating structured augmentations of a robot task instruction
for training a language-conditioned policy. Unlike free-form rephrasing,
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
a specific element of the task while preserving its meaning.
Original task: "{base_task}"
Produce variants along five named axes. Each axis has a target count.
The whole batch should expose the policy to maximum linguistic diversity
WITHOUT changing what the robot is supposed to do.
Axes and target counts:
synonym_paraphrase ({n_synonym}):
Different wording / verbs / sentence structure. ALL information
from the original task is preserved — same object, same arm
specification if present, same orientation if present, same grasp
if present.
omit_arm ({n_omit_arm}):
Drop the left/right/both arm specification from the task. Skip
entirely (emit 0 entries) if the original task does NOT mention an
arm. Do not invent an arm specification just to omit it.
omit_orientation ({n_omit_orientation}):
Drop orientation cues (upright, sideways, facing the user,
long-edge-first, etc.). Skip entirely if no orientation cue is
present in the original task.
omit_grasp_method ({n_omit_grasp_method}):
Drop the grip / grasp method specification (pinch, wrap, hold by
the rim, etc.). Skip entirely if no grasp method is mentioned.
combined_omissions ({n_combined}):
Combine TWO of the above omissions simultaneously (e.g. drop both
arm and orientation). Skip entirely if fewer than two of (arm,
orientation, grasp_method) appear in the original task.
Hard rules:
- Each variant MUST preserve the core action, the target object, AND
the goal / destination. Do not change which object is involved, where
it goes, or the high-level action. "Navigate to the stove" may become
"go to the stove" or "head over to the stove" — it must NEVER become
"wander around the kitchen", "explore the room", or anything that
drops or generalises the stove destination. If you cannot vary the
wording without changing the goal, emit fewer variants.
- Only the FIVE listed elements (wording, arm, orientation, grasp
method, or a combination) may be varied or omitted. The verb's
meaning, the object, and the destination are fixed.
- Each variant is plain prose, no markdown, no quotes, no list numbers.
- Each variant must be DISTINCT from every other variant in the entire
output, both within and across axes. Near-duplicates are not allowed.
- If an axis cannot reach its target count because the original task
lacks the omittable element, emit fewer entries — do NOT pad the
axis with paraphrases that belong to a different axis.
- Variants should not all start with verbs — vary sentence structure
(some imperative, some polite request, some question).
Output strictly valid JSON of shape:
{{
"synonym_paraphrase": ["<v1>", "<v2>", ...],
"omit_arm": ["<v1>", "<v2>", ...],
"omit_orientation": ["<v1>", ...],
"omit_grasp_method": ["<v1>", ...],
"combined_omissions": ["<v1>", ...]
}}
@@ -0,0 +1,32 @@
You are generating training data for a Hi Robot-style policy. We need
{n} alternative phrasings of the same robot task so the policy sees
diverse user prompts during training instead of the same canonical
string repeated every frame.
Original task:
"{base_task}"
Generate exactly {n} alternative phrasings of the same task. Vary:
- formality (casual / polite / curt)
- verbosity (mostly short imperative; occasional polite request)
- word choice (synonyms, different verbs)
- sentence structure (imperative / question / suggestion)
Hard rules:
- Each phrasing MUST preserve the exact meaning of the original task.
Do not change which object is involved, the destination, or the
action. Do not add extra steps. Do not invent new objects.
- Each phrasing must be a short phrase or sentence, plain prose, no
markdown, no quotes, no list numbers.
- Phrasings must be distinct — no near-duplicates.
- Output exactly {n} entries.
Output strictly valid JSON:
{{
"rephrasings": [
"<phrasing 1>",
"<phrasing 2>",
...
]
}}
@@ -0,0 +1,17 @@
The video above shows a robot manipulation episode in full. Look at
the entire video and describe in ONE concise sentence what the robot
is doing.
Rules:
- One sentence, in natural English, like a user instruction.
- Capture the goal of the demonstration, not low-level motions.
Example: "place the yellow cube into the red bin" — not "move the
end-effector down 5cm and close the gripper".
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
- Do not invent objects or actions that aren't visible.
- Do not output anything other than the JSON object below.
Output strictly valid JSON:
{{
"task": "<single concise sentence describing what the robot does in this video>"
}}
@@ -0,0 +1,32 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
"value": "<observed value>"}}
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
"above|below|near>", "object": "<obj>"}}
Generate a question of type "{question_type}". Output strictly valid JSON:
{{
"question": "<short, frame-grounded question>",
"answer": <object whose shape matches the schema above>
}}
@@ -0,0 +1,216 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:
- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
This shape lets each module operate per-episode without loading all parquet
rows into memory at once.
"""
from __future__ import annotations
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import load_tasks
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@dataclass
class EpisodeRecord:
"""Per-episode record yielded by the reader."""
episode_index: int
episode_task: str
frame_timestamps: tuple[float, ...]
frame_indices: tuple[int, ...]
data_path: Path
row_offset: int # row offset within the parquet file where this episode starts
row_count: int # number of rows for this episode
# Memoized parquet slice — populated on first ``frames_df()`` call so
# repeat queries from different modules don't re-read the whole shard.
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
def frames_df(self): # type: ignore[no-untyped-def]
"""Lazy-load the pandas slice for this episode (memoized)."""
if self._frames_df_cache is None:
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
table = pq.read_table(self.data_path)
df: pd.DataFrame = table.to_pandas()
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
drop=True
)
return self._frames_df_cache
def reconstruct_subtask_spans(
rows: Sequence[dict[str, Any]],
*,
episode_end_t: float | None = None,
) -> list[dict[str, Any]]:
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
Each span's ``end`` is the next span's ``start``. The final span's
``end`` defaults to its own ``start`` (zero-duration) — pass
``episode_end_t`` to extend it to the episode's last frame instead,
which is what downstream consumers (memory, interjection boundary
selection) expect.
Used by the ``plan`` module (plan-update pass) and the
``interjections`` module (interjection anchoring), which both need the
same span shape.
"""
sorted_rows = sorted(
(r for r in rows if r.get("style") == "subtask"),
key=lambda r: float(r["timestamp"]),
)
spans: list[dict[str, Any]] = []
for r in sorted_rows:
t = float(r["timestamp"])
if spans:
spans[-1]["end"] = t
spans.append({"text": r.get("content") or "", "start": t, "end": t})
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
spans[-1]["end"] = float(episode_end_t)
return spans
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp.
Modules use this when emitting event-style rows so the row's
timestamp matches a real parquet frame: event rows must land on an
exact frame, otherwise the per-frame event lookup the writer does
would never match them.
"""
if not frame_timestamps:
return float(t)
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
return float(nearest)
def _load_tasks_lookup(root: Path) -> dict[int, str]:
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
Returns an empty dict when the file is absent — the task description is
derived later from the video if needed. Reuses the library-level
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
frame indexed by task string with a ``task_index`` column.
"""
if not (root / DEFAULT_TASKS_PATH).exists():
return {}
tasks = load_tasks(root)
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
Episodes are yielded in ascending ``episode_index`` order. The reader does
not assume a specific chunk/file layout: it scans every ``*.parquet``
under ``data/`` and groups by ``episode_index``.
"""
tasks = _load_tasks_lookup(root)
data_dir = root / "data"
parquet_files = sorted(data_dir.rglob("*.parquet"))
only_set = set(only_episodes) if only_episodes is not None else None
for path in parquet_files:
yield from _iter_one_path(path, tasks, only_set)
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
table = pq.read_table(path)
names = table.column_names
if "episode_index" not in names:
return
episode_col = table.column("episode_index").to_pylist()
timestamp_col = (
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
)
frame_col = (
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
)
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
def _build(
ep: int,
start: int,
end: int,
task_idx: int | None,
ts_buf: list[float],
fi_buf: list[int],
) -> EpisodeRecord | None:
if only_set is not None and ep not in only_set:
return None
task = tasks.get(task_idx, "") if task_idx is not None else ""
return EpisodeRecord(
episode_index=ep,
episode_task=task,
frame_timestamps=tuple(ts_buf),
frame_indices=tuple(fi_buf),
data_path=path,
row_offset=start,
row_count=end - start,
)
cur_ep: int | None = None
start_offset = 0
ts_buf: list[float] = []
fi_buf: list[int] = []
cur_task_idx: int | None = None
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
continue
if ep != cur_ep:
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
else:
ts_buf.append(timestamp_col[i])
fi_buf.append(frame_col[i])
if cur_ep is not None:
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.
Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.
JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""
from __future__ import annotations
import json
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"plan",
"interjections",
"vqa",
)
@dataclass
class EpisodeStaging:
"""Filesystem layout for a single episode's staged module outputs."""
root: Path
episode_index: int
@property
def episode_dir(self) -> Path:
return self.root / f"episode_{self.episode_index:06d}"
def path_for(self, module: ModuleName) -> Path:
if module not in _MODULES:
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
return self.episode_dir / f"{module}.jsonl"
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
path = self.path_for(module)
path.parent.mkdir(parents=True, exist_ok=True)
# Atomic replace: a crash mid-write would otherwise leave a
# half-written JSONL file that ``read()`` would then fail to
# parse. Write to a sibling .tmp and rename so the target path
# only ever points at a complete file.
tmp_path = path.with_suffix(path.suffix + ".tmp")
with tmp_path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
f.write("\n")
tmp_path.replace(path)
return path
def read(self, module: ModuleName) -> list[dict[str, Any]]:
path = self.path_for(module)
if not path.exists():
return []
out: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
return {m: self.read(m) for m in _MODULES}
def has(self, module: ModuleName) -> bool:
return self.path_for(module).exists()
@@ -0,0 +1,332 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after all three modules have written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
Checks (per the plan's "Intermediate staging and validation" section):
- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import (
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
column_for_style,
is_view_dependent_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class ValidationReport:
"""Outcome of one validation pass across all episodes."""
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
episodes_checked: int = 0
@property
def ok(self) -> bool:
return not self.errors
def add_error(self, message: str) -> None:
self.errors.append(message)
def add_warning(self, message: str) -> None:
self.warnings.append(message)
def summary(self) -> str:
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
def classify_vqa_answer(payload: Any) -> str | None:
"""Best-effort classification of a VQA answer payload to a question type."""
if not isinstance(payload, dict):
return None
keys = set(payload.keys())
for kind, required in VQA_ANSWER_SHAPES.items():
if required.issubset(keys):
return kind
return None
@dataclass
class StagingValidator:
"""Walks the staging tree and produces a :class:`ValidationReport`."""
timestamp_atol: float = 0.0 # exact-match by default
dataset_camera_keys: tuple[str, ...] | None = None
"""Known ``observation.images.*`` keys on the dataset. When set, the
validator additionally enforces that every view-dependent row's
``camera`` field references one of these keys. Pass ``None`` (default)
to skip that cross-check (e.g. in unit tests with no real dataset)."""
def validate(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
) -> ValidationReport:
report = ValidationReport()
for record in records:
self._validate_episode(record, staging_dir, report)
report.episodes_checked += 1
return report
def _validate_episode(
self,
record: EpisodeRecord,
staging_dir: Path,
report: ValidationReport,
) -> None:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged = staging.read_all()
all_rows: list[dict[str, Any]] = []
for module_name, rows in staged.items():
for row in rows:
row = {**row, "_module": module_name}
all_rows.append(row)
frame_ts = set(record.frame_timestamps)
events: list[dict[str, Any]] = []
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
# ``_check_column_routing`` already recorded any unknown-style error;
# don't let the same ``column_for_style`` lookup raise here uncaught.
try:
column = column_for_style(row.get("style"))
except ValueError:
continue
if column == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
for row in events:
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
self._check_speech_interjection_pairs(events, report, record.episode_index)
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
self._check_vqa_json(events, report, record.episode_index)
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
def _check_camera_field(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
dataset_camera_keys: Sequence[str] | None,
) -> None:
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
style = row.get("style")
camera = row.get("camera")
try:
validate_camera_field(style, camera)
except ValueError as exc:
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
return
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
report.add_error(
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
)
def _check_vqa_uniqueness_per_frame_camera(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
counts: dict[tuple[float, str, str], int] = {}
for row in events:
if row.get("style") != "vqa":
continue
ts = row.get("timestamp")
camera = row.get("camera")
role = row.get("role")
if ts is None or camera is None or role is None:
continue # other validators flag these
key = (float(ts), str(camera), str(role))
counts[key] = counts.get(key, 0) + 1
for (ts, camera, role), n in counts.items():
if n > 1:
report.add_error(
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
)
def _check_column_routing(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
) -> None:
style = row.get("style")
module = row.get("_module")
try:
target_col = column_for_style(style)
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
def _check_event_timestamp_alignment(
self,
row: dict[str, Any],
frame_ts: set[float],
report: ValidationReport,
episode_index: int,
) -> None:
ts = row.get("timestamp")
if ts is None:
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
return
if self.timestamp_atol == 0.0:
if float(ts) not in frame_ts:
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
)
else:
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
)
def _check_speech_interjection_pairs(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
speech_ts: dict[float, int] = {}
interjection_ts: dict[float, int] = {}
for row in events:
ts = row.get("timestamp")
if ts is None:
continue
ts_f = float(ts)
if row.get("style") is None and row.get("role") == "assistant":
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
if row.get("style") == "interjection":
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
for ts in interjection_ts:
if ts not in speech_ts:
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
def _check_plan_memory_consistency(
self,
persistent: Sequence[dict[str, Any]],
events: Sequence[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
interjection_ts = sorted(
{
float(r["timestamp"])
for r in events
if r.get("style") == "interjection" and r.get("timestamp") is not None
}
)
if persistent and not plan_ts:
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
# every interjection should have a same-timestamp plan refresh
for ts in interjection_ts:
if ts not in set(plan_ts):
report.add_error(
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
)
# memory should be emitted at subtask boundaries (subset relation)
if memory_ts and subtask_ts:
mem_set = set(memory_ts)
sub_set = set(subtask_ts)
stray = sorted(mem_set - sub_set)
if stray:
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
def _check_vqa_json(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
for row in events:
if row.get("style") != "vqa" or row.get("role") != "assistant":
continue
content = row.get("content")
if content is None:
report.add_error(
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
)
continue
try:
payload = json.loads(content)
except (TypeError, ValueError) as exc:
report.add_error(
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
)
continue
shape = classify_vqa_answer(payload)
if shape is None:
report.add_error(
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
)
@@ -0,0 +1,617 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.
The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.
The client speaks one method, :meth:`VlmClient.generate_json`, which:
- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output from the server,
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
message before raising.
"""
from __future__ import annotations
import atexit
import base64
import io
import json
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
import urllib.request
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Protocol
from .config import VlmConfig
class VlmClient(Protocol):
"""Protocol every backend must implement."""
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
"""Generate one JSON-decoded response per messages list."""
@dataclass
class StubVlmClient:
"""Deterministic stub used in unit tests.
A test passes a callable that maps the *last user message text* (or, if
that is empty, the full message list) to a JSON-serializable response.
"""
responder: Callable[[Sequence[dict[str, Any]]], Any]
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
return [self.responder(list(messages)) for messages in messages_batch]
def _strip_to_json(text: str) -> Any:
text = text.strip()
# Strip <think>...</think> blocks (Qwen3 Thinking style)
while "<think>" in text and "</think>" in text:
start = text.find("<think>")
end = text.find("</think>", start) + len("</think>")
text = (text[:start] + text[end:]).strip()
# Strip ```json ... ``` fences from chat-tuned backbones
if text.startswith("```"):
first = text.find("\n")
last = text.rfind("```")
if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip()
try:
return json.loads(text)
except (ValueError, json.JSONDecodeError):
pass
# Fall back to extracting the first balanced {...} block.
obj_text = _extract_first_json_object(text)
if obj_text is None:
raise json.JSONDecodeError("No JSON object found", text, 0)
return json.loads(obj_text)
def _extract_first_json_object(text: str) -> str | None:
"""Return the first balanced ``{...}`` substring, ignoring braces in
string literals. Returns ``None`` if no balanced block is found."""
start = text.find("{")
if start < 0:
return None
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
ch = text[i]
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
# Note: ``escape`` is always False here — the ``if escape`` branch
# above already handled and reset it.
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
@dataclass
class _GenericTextClient:
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
config: VlmConfig
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
temp = temperature if temperature is not None else self.config.temperature
raw = self.generate_text(messages_batch, max_tok, temp)
out: list[Any] = []
for messages, text in zip(messages_batch, raw, strict=True):
try:
out.append(_strip_to_json(text))
continue
except (ValueError, json.JSONDecodeError):
pass
retry = list(messages) + [
{"role": "assistant", "content": text},
{
"role": "user",
"content": (
"Your previous reply was not valid JSON. "
"Reply with strictly valid JSON, no prose, no fences."
),
},
]
retry_text = self.generate_text([retry], max_tok, temp)[0]
try:
out.append(_strip_to_json(retry_text))
except (ValueError, json.JSONDecodeError):
# After retry: log preview and return None instead of crashing
# the whole pipeline. Modules treat None as "skip".
preview = retry_text.strip().replace("\n", " ")[:200]
print(
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
flush=True,
)
out.append(None)
return out
def make_vlm_client(config: VlmConfig) -> VlmClient:
"""Build the shared VLM client.
Only the ``openai`` backend is supported for now. The shipped workflow
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
optionally auto-spawning the server via ``auto_serve`` /
``serve_command``). The former in-process ``vllm`` / ``transformers``
backends were removed to keep the support surface to the HF Jobs path.
For ``stub``, construct :class:`StubVlmClient` directly with a responder
callable; it is rejected here to make accidental misuse obvious.
"""
if config.backend == "openai":
return _make_openai_client(config)
if config.backend == "stub":
raise ValueError(
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
)
if config.backend in {"vllm", "transformers"}:
raise ValueError(
f"backend={config.backend!r} (in-process local model) is not supported for now — "
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
)
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
def _make_openai_client(config: VlmConfig) -> VlmClient:
"""Backend that talks to any OpenAI-compatible server.
Compatible with ``vllm serve``, ``transformers serve``,
``ktransformers serve``, and hosted endpoints. By default the server
is expected to be already running. Set ``auto_serve=True`` to have
this client spawn one (default: ``transformers serve``), wait until
it's ready, and tear it down on process exit.
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
auto-converted to ``image_url`` data-URLs. Video blocks
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
multi-frame ``video_url`` items where supported.
"""
try:
from openai import OpenAI # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"openai package is required for backend='openai'. Install with `pip install openai`."
) from exc
api_base = config.api_base
api_key = config.api_key
auto_serve = config.auto_serve
api_bases: list[str] = [api_base]
print(
f"[lerobot-annotate] backend=openai model={config.model_id} "
f"api_base={api_base} auto_serve={auto_serve}",
flush=True,
)
if auto_serve:
if config.parallel_servers > 1:
print(
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
flush=True,
)
api_bases = _spawn_parallel_inference_servers(config)
elif _server_is_up(api_base):
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
else:
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
api_base = _spawn_inference_server(config)
api_bases = [api_base]
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
# round-robin counter for parallel mode
rr_counter = {"i": 0}
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
# rejects it with HTTP 422. Send it only when explicitly opted in via
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
extra_body: dict[str, Any] = {}
if send_mm_kwargs and mm_kwargs:
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
if config.chat_template_kwargs:
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
if extra_body:
kwargs["extra_body"] = extra_body
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
if len(batch) <= 1 or config.client_concurrency <= 1:
return [_one_call(messages, max_tok, temp) for messages in batch]
# Parallel fan-out — vllm batches these on the server side.
max_workers = min(config.client_concurrency, len(batch))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
return [f.result() for f in futures]
return _GenericTextClient(_gen, config)
def _bind_serve_port(cmd: str, port: int) -> str:
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
if present, else append ``--port`` when the command omits it (leaving an
explicit ``--port`` untouched). Shared by the single- and parallel-server
paths so a serve_command never reaches the server with a literal
``{port}``."""
if "{port}" in cmd:
return cmd.replace("{port}", str(port))
if "--port" not in cmd:
return f"{cmd} --port {port}"
return cmd
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
"""Spawn ``config.parallel_servers`` independent vllm replicas.
Each replica:
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
- listens on ``serve_port + i``
- is shut down via the same atexit hook as the single-server path
Returns the list of ``api_base`` URLs the client should round-robin
across.
"""
n = config.parallel_servers
api_bases: list[str] = []
procs: list[subprocess.Popen] = []
ready_events: list[threading.Event] = []
# Multiple readiness signals — uvicorn's own banner is suppressed at
# ``--uvicorn-log-level warning``, so we also accept vllm's own
# "Starting vLLM API server" line and the route-listing line. The
# HTTP probe below is the ultimate fallback.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
# Single lock for all server-stream threads so multibyte chars from
# different servers don't interleave and tear UTF-8 sequences.
print_lock = threading.Lock()
base_cmd = config.serve_command or (
f"vllm serve {shlex.quote(config.model_id)} "
f"--tensor-parallel-size 1 "
f"--max-model-len {config.max_model_len or 32768} "
f"--uvicorn-log-level warning"
)
num_gpus = config.num_gpus if config.num_gpus > 0 else n
for i in range(n):
port = config.serve_port + i
gpu = i % num_gpus
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = _bind_serve_port(base_cmd, port)
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
env=env,
)
procs.append(proc)
ready = threading.Event()
ready_events.append(ready)
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
# Read whole lines and emit each line atomically under the
# shared print_lock so output from N servers stays readable.
assert p.stdout is not None
for line in iter(p.stdout.readline, ""):
with print_lock:
sys.stdout.write(f"[server-{idx}] {line}")
if not line.endswith(("\n", "\r")):
sys.stdout.write("\n")
sys.stdout.flush()
if any(m in line for m in ready_markers):
ev.set()
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
while not ev.is_set() and p.poll() is None:
if _server_is_up(base):
print(f"[server-{idx}] ready (http probe)", flush=True)
ev.set()
return
time.sleep(2)
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
def _shutdown() -> None:
for i, p in enumerate(procs):
if p.poll() is None:
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
p.send_signal(signal.SIGINT)
for p in procs:
try:
p.wait(timeout=15)
except subprocess.TimeoutExpired:
p.kill()
p.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
for i, p in enumerate(procs):
if p.poll() is not None:
raise RuntimeError(
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
)
time.sleep(2)
if any(not ev.is_set() for ev in ready_events):
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
return api_bases
def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
url = api_base.rstrip("/") + "/models"
# ``api_base`` is the user-configured local-server URL we just spawned
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
# is for arbitrary user-controlled URLs with file:/ schemes which
# cannot reach this code path.
try:
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
return resp.status == 200
except Exception: # noqa: BLE001
return False
def _spawn_inference_server(config: VlmConfig) -> str:
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
accepts ``/v1/models``, and register a shutdown hook.
Streams the server's stdout/stderr to the parent terminal in
real-time on a background thread so users can see model-load
progress and errors as they happen.
Returns the full ``api_base`` URL the OpenAI client should use.
"""
cmd = config.serve_command
if not cmd:
cmd = (
f"transformers serve {shlex.quote(config.model_id)} "
f"--port {config.serve_port} --continuous-batching"
)
# Bind the single server to ``serve_port`` (what ``api_base`` below
# targets): substitute a literal ``{port}`` placeholder, else append
# ``--port``. Without this a serve_command carrying ``{port}`` would
# reach the server unsubstituted and fail to parse.
cmd = _bind_serve_port(cmd, config.serve_port)
api_base = f"http://localhost:{config.serve_port}/v1"
print(f"[server] launching: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
# Watch the server output for the uvicorn readiness banner. This is
# more reliable than polling /v1/models because transformers serve
# rescans its cache on every model-list request, which can exceed
# the urllib timeout and trigger an infinite probe loop.
ready_event = threading.Event()
# See _spawn_parallel_inference_servers for why we accept these.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
def _probe() -> None:
while not ready_event.is_set() and proc.poll() is None:
if _server_is_up(api_base):
print("[server] ready (http probe)", flush=True)
ready_event.set()
return
time.sleep(2)
threading.Thread(target=_probe, daemon=True).start()
def _stream_output() -> None:
# Read raw chunks instead of iterating lines so tqdm progress
# bars (which overwrite using \r) flush in real time.
assert proc.stdout is not None
buf = ""
prefix_started = False
while True:
ch = proc.stdout.read(1)
if ch == "":
# process exited; flush any tail
if buf:
sys.stdout.write(buf)
sys.stdout.flush()
return
if not prefix_started:
sys.stdout.write("[server] ")
prefix_started = True
sys.stdout.write(ch)
sys.stdout.flush()
buf += ch
if ch in ("\n", "\r"):
if any(marker in buf for marker in ready_markers):
ready_event.set()
buf = ""
prefix_started = False
threading.Thread(target=_stream_output, daemon=True).start()
def _shutdown() -> None:
if proc.poll() is None:
print(f"[server] stopping pid={proc.pid}", flush=True)
proc.send_signal(signal.SIGINT)
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while time.monotonic() < deadline:
if proc.poll() is not None:
raise RuntimeError(
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
f"See [server] log lines above for the cause."
)
if ready_event.wait(timeout=2):
return api_base
proc.terminate()
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
def _to_openai_messages(
messages: Sequence[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
"""Convert internal messages to OpenAI chat format.
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
(``fps`` from ``video_url`` blocks) are extracted out so the caller
can pass them via ``extra_body.mm_processor_kwargs`` rather than
inside the content blocks (which transformers serve rejects).
File-URL video blocks are inlined as base64 data URLs.
"""
out_messages: list[dict[str, Any]] = []
mm_kwargs: dict[str, Any] = {}
for message in messages:
content = message.get("content")
if not isinstance(content, list):
out_messages.append({"role": message["role"], "content": content})
continue
out_blocks: list[dict[str, Any]] = []
for block in content:
block_type = block.get("type") if isinstance(block, dict) else None
if block_type == "text":
out_blocks.append({"type": "text", "text": block.get("text", "")})
elif block_type == "image":
out_blocks.append(
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
)
elif block_type == "video":
frames = block.get("video", [])
for img in frames:
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
elif block_type == "video_url":
video_url = dict(block["video_url"])
url = video_url.get("url", "")
if url.startswith("file://"):
video_url["url"] = _file_to_data_url(url[len("file://") :])
out_blocks.append({"type": "video_url", "video_url": video_url})
fps = block.get("fps")
if fps is not None:
mm_kwargs["fps"] = fps
else:
out_blocks.append(block)
out_messages.append({"role": message["role"], "content": out_blocks})
return out_messages, mm_kwargs
def _file_to_data_url(path: str) -> str:
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
return f"data:video/mp4;base64,{b64}"
def _pil_to_data_url(image: Any) -> str:
"""Encode a PIL.Image as a base64 data URL."""
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
return f"data:image/png;base64,{b64}"
@@ -0,0 +1,341 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.
For every episode the writer:
1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. writes the parquet shard back in place.
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant —
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
it directly rather than reading a redundant per-row column.
Invariants enforced here (and re-checked by the validator):
- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
(timestamps come straight from the source parquet — never recomputed);
- every row passes ``column_for_style(style)``.
"""
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
validate_camera_field,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
# Tool schema constants live in lerobot.datasets.language — single
# source of truth. Re-exported here so existing imports
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
# keep working.
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
# events are bucketed per-frame, but within a frame we still want determinism
return (
row.get("style") or "",
row.get("role") or "",
row.get("camera") or "",
)
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
"""Coerce a staged row into the language-column struct shape.
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
writer infers the parquet struct schema from insertion order, so
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
"""
camera = row.get("camera")
validate_camera_field(style, camera)
out: dict[str, Any] = {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
}
if with_timestamp:
out["timestamp"] = float(row["timestamp"])
out["camera"] = None if camera is None else str(camera)
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
return out
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
if style not in PERSISTENT_STYLES:
raise ValueError(
f"persistent slice contains row with non-persistent style {style!r}; "
"row would be misrouted under column_for_style()"
)
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
if "role" not in row:
# Friendly error from the writer instead of a raw KeyError below;
# the validator doesn't check ``role`` yet.
raise ValueError(f"persistent row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=True)
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
style = row.get("style")
if style is not None and style not in EVENT_ONLY_STYLES:
raise ValueError(
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
)
if column_for_style(style) != LANGUAGE_EVENTS:
raise ValueError(f"event row with style {style!r} would not route to language_events")
if "role" not in row:
raise ValueError(f"event row missing role: {row!r}")
return _normalize_row(row, style, with_timestamp=False)
def _normalize_tool_calls(value: Any) -> list[Any] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
return list(value)
def _validate_atom_invariants(row: dict[str, Any]) -> None:
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
has_content = row.get("content") is not None
has_tools = row.get("tool_calls") is not None
if not (has_content or has_tools):
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
if row.get("style") is None and not has_tools:
raise ValueError(f"style=None requires tool_calls: {row!r}")
def _validate_speech_atom(row: dict[str, Any]) -> None:
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
if row.get("style") is not None:
return # not a speech atom
if row.get("role") != "assistant":
raise ValueError(f"speech atom must have role=assistant: {row!r}")
if row.get("content") is not None:
raise ValueError(f"speech atom must have content=null: {row!r}")
tool_calls = row.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
first = tool_calls[0]
if not isinstance(first, dict):
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
if first.get("type") != "function":
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
fn = first.get("function") or {}
if fn.get("name") != "say":
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
args = fn.get("arguments") or {}
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
@dataclass
class LanguageColumnsWriter:
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
drop_existing_subtask_index: bool = True
def write_all(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> list[Path]:
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
for record in records:
episodes_by_path[record.data_path].append(record)
written: list[Path] = []
for path, eps in episodes_by_path.items():
self._rewrite_one(path, eps, staging_dir, root)
written.append(path)
return written
def _rewrite_one(
self,
path: Path,
episodes: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> None:
table = pq.read_table(path)
n_rows = table.num_rows
# Ensure we cover every episode in the file. Episodes that don't have
# staging artifacts are passed through with empty annotation lists —
# this keeps the writer idempotent and safe for partial reruns.
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
for record in episodes:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged_per_ep[record.episode_index] = staging.read_all()
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
for ep_index, ep_staged in staged_per_ep.items():
persistent_rows: list[dict[str, Any]] = []
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
for _module_name, rows in ep_staged.items():
for row in rows:
style = row.get("style")
if column_for_style(style) == LANGUAGE_PERSISTENT:
persistent_rows.append(row)
else:
event_rows.append(row)
persistent_rows.sort(key=_row_persistent_sort_key)
normalized_persistent = []
for r in persistent_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
normalized_persistent.append(_normalize_persistent_row(r))
persistent_by_ep[ep_index] = normalized_persistent
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
for r in event_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
ts = float(r["timestamp"])
buckets[ts].append(_normalize_event_row(r))
for ts in list(buckets.keys()):
buckets[ts].sort(key=_row_event_sort_key)
events_by_ep_ts[ep_index] = buckets
episode_col = (
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
)
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
if episode_col is None or ts_col is None:
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
per_row_persistent: list[list[dict[str, Any]]] = []
per_row_events: list[list[dict[str, Any]]] = []
for i in range(n_rows):
ep = episode_col[i]
ts = float(ts_col[i])
per_row_persistent.append(persistent_by_ep.get(ep, []))
buckets = events_by_ep_ts.get(ep, {})
per_row_events.append(buckets.get(ts, []))
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
self,
table: pa.Table,
persistent: list[list[dict[str, Any]]],
events: list[list[dict[str, Any]]],
*,
drop_old: bool,
) -> pa.Table:
cols = []
names = []
for name in table.column_names:
if drop_old and name == "subtask_index":
continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
continue # we'll re-add canonical versions
# Strip any legacy ``tools`` column previously emitted by older
# writers — the schema no longer uses it (constant lives in
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
if name == "tools":
continue
cols.append(table.column(name))
names.append(name)
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
# exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
return pa.Table.from_arrays(cols, names=names)
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
"""Build a canonical speech tool-call atom for the events column."""
return {
"role": "assistant",
"content": None,
"style": None,
"timestamp": float(timestamp),
"camera": None,
"tool_calls": [
{
"type": "function",
"function": {
"name": "say",
"arguments": {"text": text},
},
}
],
}
+70
View File
@@ -18,6 +18,7 @@ from __future__ import annotations
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)
########################################################################################
# Teleoperator smooth handover helpers
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
# being a root module.
########################################################################################
def teleop_supports_feedback(teleop) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback (i.e. have non-empty
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
``target_pos`` is expected to be in the teleop's action/feedback key space.
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
the robot action key space directly.
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
follower robot does not receive new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def follower_smooth_move_to(
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm to
the follower, the follower is brought to the teleop's current pose so the
robot meets the operator's hand rather than jumping to it on the first frame.
Both ``current`` and ``target`` must be in the robot action key space
(i.e. the output of ``robot_action_processor``).
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
+37 -4
View File
@@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -75,6 +96,8 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -100,6 +123,10 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -112,7 +139,9 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
def save_training_state(
@@ -120,6 +149,8 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -131,10 +162,12 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
+2 -2
View File
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
private: bool | None = None
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
+2
View File
@@ -73,6 +73,8 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
# Whether to record eval rollouts as a LeRobot v3.0 dataset on disk.
recording: bool = False
def __post_init__(self) -> None:
if self.batch_size == 0:
+6
View File
@@ -177,6 +177,12 @@ class TrainPipelineConfig(HubMixin):
)
active_cfg = self.trainable_config
if self.rename_map and active_cfg.pretrained_path is None:
raise ValueError(
"`rename_map` requires a pretrained policy checkpoint. "
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{active_cfg.type}"
+2 -1
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -82,6 +82,7 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+21 -6
View File
@@ -286,6 +286,8 @@ def aggregate_datasets(
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
chunk_size: int | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -303,6 +305,8 @@ def aggregate_datasets(
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
logging.info("Start aggregate_datasets")
@@ -351,8 +355,12 @@ def aggregate_datasets(
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
videos_idx = aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
)
data_idx = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
@@ -367,7 +375,9 @@ def aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
def aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
@@ -379,6 +389,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -439,7 +450,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
@@ -477,7 +488,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
@@ -493,6 +504,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -538,6 +550,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
)
# Record the mapping from source to actual destination
@@ -614,6 +627,7 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -630,6 +644,7 @@ def append_or_create_parquet_file(
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -649,7 +664,7 @@ def append_or_create_parquet_file(
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if dst_size + src_size >= max_mb:
if not concatenate or dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
+2
View File
@@ -59,6 +59,8 @@ class RunningQuantileStats:
batch: An array where all dimensions except the last are batch dimensions.
"""
batch = batch.reshape(-1, batch.shape[-1])
# Promote integer and low-precision inputs before computing squared statistics.
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
num_elements, vector_length = batch.shape
if self._count == 0:
+6
View File
@@ -261,6 +261,8 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -270,6 +272,8 @@ def merge_datasets(
datasets: List of LeRobotDatasets to merge.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -284,6 +288,8 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
concatenate_videos=concatenate_videos,
concatenate_data=concatenate_data,
)
merged_dataset = LeRobotDataset(
+3 -2
View File
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
private: bool | None = None,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
**card_kwargs,
@@ -543,7 +543,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
tag_version: If ``True``, create a Git tag for the current codebase
version.
push_videos: If ``False``, skip uploading the ``videos/`` directory.
private: If ``True``, create a private repository.
private: If ``True``, create a private repository. If ``None``
(default), defer to the org default on the Hub (only affects orgs).
allow_patterns: Glob pattern(s) restricting which files to upload.
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
of ``upload_folder`` for very large datasets.
+122 -32
View File
@@ -14,14 +14,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
"""Sampler over episode frames that stores only per-episode boundaries.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -30,57 +52,125 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
seed: int = 0,
):
"""Sampler that optionally incorporates episode boundary information.
"""
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
if not indices:
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.seed = seed
self._epoch = 0
self._start_index = 0
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
else:
for i in self.indices:
yield i
for k in range(start, self._num_frames):
yield self._frame_index(k)
def __len__(self) -> int:
return len(self.indices)
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
+21 -1
View File
@@ -481,8 +481,10 @@ def reencode_video(
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
start_time_s: float | None = None,
end_time_s: float | None = None,
) -> None:
"""Re-encode a video file using the given encoder configuration.
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
Args:
input_video_path: Existing video file to read.
@@ -491,10 +493,17 @@ def reencode_video(
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
start_time_s: When set, trim the output to start at this timestamp (seconds).
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
@@ -526,6 +535,10 @@ def reencode_video(
width = int(in_stream.width)
height = int(in_stream.height)
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
if start_time_s is not None:
src.seek(int(start_time_s * av.time_base), backward=True)
with av.open(
tmp_output_video_path,
mode="w",
@@ -539,7 +552,14 @@ def reencode_video(
out_stream.height = height
for frame in src.decode(in_stream):
frame_time_s = frame.time
if start_time_s is not None and frame_time_s < start_time_s:
continue
if end_time_s is not None and frame_time_s >= end_time_s:
break
frame = frame.reformat(width=width, height=height, format=pix_fmt)
if start_time_s is not None:
frame.pts = None # reset timestamps so the trimmed output starts at t=0
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
+2
View File
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
@@ -43,6 +44,7 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
"PI0FastConfig",
+40 -2
View File
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
@@ -56,6 +57,7 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -88,7 +90,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
"molmoact2".
Returns:
The policy class corresponding to the given name.
@@ -151,6 +154,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "molmoact2":
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -168,7 +179,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
"smolvla", "wall_x", "molmoact2".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -203,6 +214,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -231,6 +246,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
dataset_meta: Any | None
def make_pre_post_processors(
@@ -406,6 +422,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -414,6 +431,23 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MolmoAct2Config):
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
processors = make_molmoact2_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
@@ -499,6 +533,10 @@ def make_policy(
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
if ds_meta is not None:
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
if callable(set_dataset_feature_metadata):
set_dataset_feature_metadata(ds_meta.features)
kwargs["config"] = cfg
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_molmoact2_README.md
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import MolmoAct2Policy
from .processor_molmoact2 import make_molmoact2_pre_post_processors
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
@@ -0,0 +1,519 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
CosineDecayWithWarmupSchedulerConfig,
LRSchedulerConfig,
OptimizerConfig,
)
from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
class MolmoAct2Config(PreTrainedConfig):
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
checkpoint_path: str = "allenai/MolmoAct2"
checkpoint_revision: str | None = None
checkpoint_force_download: bool = False
n_obs_steps: int = 1
chunk_size: int = 30
n_action_steps: int = 30
action_mode: str = "both"
inference_action_mode: str | None = None
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
discrete_generation_max_steps: int | None = None
norm_tag: str | None = None
setup_type: str = ""
control_mode: str = ""
image_keys: list[str] = field(default_factory=list)
normalize_language: bool = True
add_setup_tokens: bool = True
add_control_tokens: bool = True
normalize_gripper: bool = False
num_state_tokens: int = 256
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
# image/prompt/state/action token layout. Override only for unusual long prompts.
max_sequence_length: int | None = None
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
expected_max_action_dim: int = 32
# Flow-matching training knobs copied from the original MolmoAct2 training path.
num_flow_timesteps: int = 8
flow_matching_cutoff: float = 1.0
flow_matching_time_offset: float = 0.001
flow_matching_time_scale: float = 0.999
flow_matching_beta_alpha: float = 1.0
flow_matching_beta_beta: float = 1.5
num_inference_steps: int | None = None
mask_action_dim_padding: bool = True
enable_inference_cuda_graph: bool = True
# MolmoAct2-local eval option. When enabled, stochastic continuous action
# generation uses a rollout-local generator derived from eval_seed.
per_episode_seed: bool = False
eval_seed: int | None = None
rtc_config: RTCConfig | None = None
# Default is full finetuning with gradients from the action expert flowing into the VLM.
enable_lora_vlm: bool = False
lora_rank: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_bias: str = "none"
enable_lora_action_expert: bool = False
enable_knowledge_insulation: bool = False
freeze_embedding: bool = True
train_action_expert_only: bool = False
gradient_checkpointing: bool = False
model_dtype: str = "bfloat16"
softmax_auxiliary_loss: bool = True
softmax_auxiliary_loss_scale: float = 1e-4
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
optimizer_lr: float = 1e-5
optimizer_vit_lr: float = 5e-6
optimizer_connector_lr: float = 5e-6
optimizer_action_expert_lr: float = 5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-6
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES,
"ACTION": NormalizationMode.QUANTILES,
}
)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
super().__post_init__()
if self.action_mode not in {"continuous", "discrete", "both"}:
raise ValueError(
f"Unsupported action_mode={self.action_mode!r}. "
"Expected one of {'continuous', 'discrete', 'both'}."
)
if self.inference_action_mode not in {None, "continuous", "discrete"}:
raise ValueError(
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
"Expected one of {None, 'continuous', 'discrete'}."
)
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
if self.train_action_expert_only and self.action_mode != "continuous":
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
if self.train_action_expert_only and self.enable_lora_vlm:
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
if self.enable_lora_action_expert and not self.enable_lora_vlm:
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
if self.chunk_size < 1:
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
if self.n_action_steps < 1:
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
)
if self.expected_max_action_dim != 32:
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
raise ValueError(
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
)
if self.lora_rank < 1:
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
if self.lora_alpha < 1:
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
if not 0 <= self.lora_dropout <= 1:
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
if self.lora_bias not in {"none", "all", "lora_only"}:
raise ValueError(
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
)
if self.discrete_loss_token_weighting not in {
"none",
"token",
"root_tokens",
"root_subsegments",
"root_subsegments_root_tokens",
}:
raise ValueError(
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
)
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
raise ValueError(
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
)
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
self.dataset_feature_names = {}
for key in (ACTION, OBS_STATE):
feature = features.get(key) if isinstance(features, dict) else None
if isinstance(feature, dict) and feature.get("names") is not None:
self.dataset_feature_names[key] = feature["names"]
def validate_features(self) -> None:
"""Validate and set up MolmoAct2 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"MolmoAct2 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(0,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
@@ -0,0 +1,17 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
@@ -0,0 +1,237 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_tokenizer_location(
tokenizer_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["tokenizer"]
tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(tokenizer)
self.bpe_tokenizer = self.tokenizer
def __call__(self, action_chunk: np.array) -> np.array:
from scipy.fft import dct
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
from scipy.fft import idct
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
from scipy.fft import dct
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert min_vocab_size <= vocab_size, (
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
)
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)
@classmethod
def from_pretrained_local(
cls,
pretrained_model_name_or_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> "UniversalActionProcessor":
location = Path(
_resolve_tokenizer_location(
pretrained_model_name_or_path,
revision=revision,
force_download=force_download,
)
)
processor_config = {}
processor_config_path = location / "processor_config.json"
if processor_config_path.exists():
import json
processor_config = json.loads(processor_config_path.read_text())
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
return cls(
tokenizer,
scale=processor_config.get("scale", 10),
vocab_size=processor_config.get("vocab_size", 1024),
min_token=processor_config.get("min_token", 0),
action_dim=processor_config.get("action_dim"),
time_horizon=processor_config.get("time_horizon"),
)
@@ -0,0 +1,553 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class MolmoAct2VitConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
>>> # Initializing a MolmoAct2VitConfig
>>> configuration = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
>>> model = MolmoAct2VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
base_config_key = "vit_config"
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
num_hidden_layers: int = 27,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
hidden_act: str = "gelu_pytorch_tanh",
layer_norm_eps: float = 1e-6,
image_default_input_size: tuple[int, int] = (378, 378),
image_patch_size: int = 14,
image_num_pos: int = 577,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
initializer_range: float = 0.02,
float32_attention: bool = True,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.image_default_input_size = image_default_input_size
self.image_patch_size = image_patch_size
self.image_num_pos = image_num_pos
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.initializer_range = initializer_range
self.float32_attention = float32_attention
@property
def image_num_patch(self):
h, w = self.image_default_input_size
return h // self.image_patch_size, w // self.image_patch_size
class MolmoAct2AdapterConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
>>> vit_config = MolmoAct2VitConfig()
>>> adapter_config = MolmoPoolingConfig()
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
>>> # Accessing the model configuration
>>> vit_configuration = model.vit_config
>>> adapter_configuration = model.adapter_config
```"""
model_type = "molmoact2"
base_config_key = "adapter_config"
def __init__(
self,
vit_layers: tuple = (-3, -9),
pooling_attention_mask: bool = False,
hidden_size: int = 1152,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
float32_attention: bool = True,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
hidden_act: str = "silu",
intermediate_size: int = 18944,
text_hidden_size: int = 3584,
image_feature_dropout: float = 0.0,
initializer_range: float = 0.02,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.vit_layers = vit_layers
self.pooling_attention_mask = pooling_attention_mask
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.float32_attention = float32_attention
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.text_hidden_size = text_hidden_size
self.image_feature_dropout = image_feature_dropout
self.initializer_range = initializer_range
class MolmoAct2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
>>> # Initializing a MolmoAct2TextConfig
>>> configuration = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2TextModel (with random weights)
>>> model = MolmoAct2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"blocks.*.self_attn.att_proj": "colwise",
"blocks.*.self_attn.attn_out": "rowwise",
"blocks.*.mlp.ff_proj": "colwise",
"blocks.*.mlp.ff_out": "rowwise",
}
base_model_pp_plan = {
"wte": (["input_ids"], ["inputs_embeds"]),
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
"ln_f": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: int | None = 4,
head_dim: int = 128,
vocab_size: int = 152064,
additional_vocab_size: int = 128,
qkv_bias: bool = True,
num_hidden_layers: int = 48,
intermediate_size: int = 18944,
hidden_act: str = "silu",
embedding_dropout: float = 0.0,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
rope_scaling: dict[str, Any] = None,
rope_scaling_layers: list[int] | None = None,
use_qk_norm: bool = False,
qk_norm_type: str = "olmo",
layer_norm_eps: int = 1e-6,
norm_after: bool = False,
initializer_range: float = 0.02,
use_cache=True,
tie_word_embeddings=False,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.vocab_size = vocab_size
self.additional_vocab_size = additional_vocab_size
self.qkv_bias = qkv_bias
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.embedding_dropout = embedding_dropout
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_scaling_layers = rope_scaling_layers
self.use_qk_norm = use_qk_norm
self.qk_norm_type = qk_norm_type
self.layer_norm_eps = layer_norm_eps
self.norm_after = norm_after
self.initializer_range = initializer_range
self.use_cache = use_cache
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
class MolmoAct2ActionExpertConfig(PretrainedConfig):
r"""Configuration for the MolmoAct2 modern action expert."""
model_type = "molmoact2_action_expert"
base_config_key = "action_expert_config"
def __init__(
self,
max_action_horizon: int = 32,
max_action_dim: int = 32,
hidden_size: int = 1024,
num_layers: int = 32,
num_heads: int = 16,
mlp_ratio: float = 8.0 / 3.0,
ffn_multiple_of: int = 256,
timestep_embed_dim: int = 256,
dropout: float = 0.0,
attn_dropout: float = 0.0,
context_layer_norm: bool = True,
qk_norm: bool = True,
qk_norm_eps: float = 1e-6,
rope: bool = True,
causal_attn: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.max_action_horizon = max_action_horizon
self.max_action_dim = max_action_dim
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.ffn_multiple_of = ffn_multiple_of
self.timestep_embed_dim = timestep_embed_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.context_layer_norm = context_layer_norm
self.qk_norm = qk_norm
self.qk_norm_eps = qk_norm_eps
self.rope = rope
self.causal_attn = causal_attn
def to_dict(self):
output = super().to_dict()
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
# them out of the public nested config avoids duplicated sources of truth.
output.pop("max_action_horizon", None)
output.pop("max_action_dim", None)
return output
class MolmoAct2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
Example:
```python
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
>>> # Initializing a MolmoAct2VitConfig
>>> vit_config = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2AdapterConfig
>>> adapter_config = MolmoAct2AdapterConfig()
>>> # Initializing a MolmoAct2TextConfig
>>> text_config = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2Config
>>> configuration = MolmoAct2Config(
>>> vit_config=vit_config,
>>> adapter_config=adapter_config,
>>> text_config=text_config,
>>> image_start_token_id=151936,
>>> image_end_token_id=151937,
>>> image_patch_id=151938,
>>> image_col_id=151939,
>>> low_res_image_start_token_id=151940,
>>> image_low_res_id=151942,
>>> frame_start_token_id=151943,
>>> frame_end_token_id=151944,
>>> )
>>> # Initializing a model
>>> model = MolmoAct2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
sub_configs = {
"text_config": MolmoAct2TextConfig,
"vit_config": MolmoAct2VitConfig,
"adapter_config": MolmoAct2AdapterConfig,
"action_expert_config": MolmoAct2ActionExpertConfig,
}
def __init__(
self,
vit_config: MolmoAct2VitConfig = None,
adapter_config: MolmoAct2AdapterConfig = None,
text_config: MolmoAct2TextConfig = None,
action_expert_config: MolmoAct2ActionExpertConfig = None,
image_start_token_id: int = None,
low_res_image_start_token_id: int = None,
image_end_token_id: int = None,
image_low_res_id: int = None,
image_patch_id: int = None,
image_col_id: int = None,
frame_start_token_id: int = None,
frame_end_token_id: int = None,
use_frame_special_tokens: bool = True,
initializer_range: float = 0.02,
add_action_expert: bool = True,
max_action_dim: int = 32,
max_action_horizon: int = 30,
n_obs_steps: int = 30,
action_mode: str = "both",
state_format: str = "discrete",
flow_matching_num_steps: int = 10,
flow_matching_cutoff: float = 1.0,
flow_matching_time_offset: float = 0.001,
flow_matching_time_scale: float = 0.999,
flow_matching_beta_alpha: float = 1.0,
flow_matching_beta_beta: float = 1.5,
mask_action_dim_padding: bool = True,
enable_depth_reasoning: bool = False,
depth_mode: int = 2,
num_depth_codes: int = 100,
action_expert_depth_gate: bool = False,
action_expert_depth_gate_per_layer: bool = False,
action_expert_depth_gate_init_bias: float = -4.0,
action_output_token_id: int = None,
action_start_token_id: int = None,
action_end_token_id: int = None,
action_token_start_id: int = None,
num_action_tokens: int = 0,
depth_output_token_id: int = None,
depth_start_token_id: int = None,
depth_end_token_id: int = None,
depth_token_start_id: int = None,
num_depth_tokens: int = 0,
state_start_token_id: int = None,
state_end_token_id: int = None,
state_token_start_id: int = None,
num_state_tokens: int = 0,
add_setup_tokens: bool = True,
add_control_tokens: bool = True,
norm_stats_filename: str = "norm_stats.json",
**kwargs,
):
super().__init__(**kwargs)
if vit_config is None:
self.vit_config = MolmoAct2VitConfig()
elif isinstance(vit_config, dict):
self.vit_config = MolmoAct2VitConfig(**vit_config)
else:
self.vit_config = vit_config
if adapter_config is None:
self.adapter_config = MolmoAct2AdapterConfig()
elif isinstance(adapter_config, dict):
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
else:
self.adapter_config = adapter_config
if text_config is None:
self.text_config = MolmoAct2TextConfig()
elif isinstance(text_config, dict):
self.text_config = MolmoAct2TextConfig(**text_config)
else:
self.text_config = text_config
self.add_action_expert = bool(add_action_expert)
if not self.add_action_expert:
self.action_expert_config = None
elif action_expert_config is None:
self.action_expert_config = MolmoAct2ActionExpertConfig(
max_action_horizon=max_action_horizon,
max_action_dim=max_action_dim,
num_layers=self.text_config.num_hidden_layers,
)
elif isinstance(action_expert_config, dict):
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
else:
self.action_expert_config = action_expert_config
if self.add_action_expert:
self.action_expert_config.max_action_dim = int(max_action_dim)
self.action_expert_config.max_action_horizon = int(max_action_horizon)
self._validate_release_action_config(
state_format=state_format,
)
self.image_start_token_id = image_start_token_id
self.low_res_image_start_token_id = low_res_image_start_token_id
self.image_end_token_id = image_end_token_id
self.image_low_res_id = image_low_res_id
self.image_high_res_id = image_patch_id
self.image_patch_id = image_patch_id
self.image_col_id = image_col_id
self.frame_start_token_id = frame_start_token_id
self.frame_end_token_id = frame_end_token_id
self.use_frame_special_tokens = use_frame_special_tokens
self.initializer_range = initializer_range
self.max_action_dim = max_action_dim
self.max_action_horizon = max_action_horizon
self.n_obs_steps = n_obs_steps
self.action_mode = action_mode
self.state_format = state_format
self.flow_matching_num_steps = flow_matching_num_steps
self.flow_matching_cutoff = flow_matching_cutoff
self.flow_matching_time_offset = flow_matching_time_offset
self.flow_matching_time_scale = flow_matching_time_scale
self.flow_matching_beta_alpha = flow_matching_beta_alpha
self.flow_matching_beta_beta = flow_matching_beta_beta
self.mask_action_dim_padding = mask_action_dim_padding
self.enable_depth_reasoning = enable_depth_reasoning
self.depth_mode = depth_mode
self.num_depth_codes = num_depth_codes
self.action_expert_depth_gate = action_expert_depth_gate
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
self.action_output_token_id = action_output_token_id
self.action_start_token_id = action_start_token_id
self.action_end_token_id = action_end_token_id
self.action_token_start_id = action_token_start_id
self.num_action_tokens = num_action_tokens
self.depth_output_token_id = depth_output_token_id
self.depth_start_token_id = depth_start_token_id
self.depth_end_token_id = depth_end_token_id
self.depth_token_start_id = depth_token_start_id
self.num_depth_tokens = num_depth_tokens
self.state_start_token_id = state_start_token_id
self.state_end_token_id = state_end_token_id
self.state_token_start_id = state_token_start_id
self.num_state_tokens = num_state_tokens
self.add_setup_tokens = add_setup_tokens
self.add_control_tokens = add_control_tokens
self.norm_stats_filename = norm_stats_filename
@staticmethod
def _validate_release_action_config(
*,
state_format: str,
) -> None:
if state_format != "discrete":
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
@property
def image_num_patch(self):
assert self.vit_config is not None
return self.vit_config.image_num_patch
@property
def num_attention_heads(self):
return self.text_config.num_attention_heads
@property
def num_key_value_heads(self):
return self.text_config.num_key_value_heads
@property
def head_dim(self):
return self.text_config.head_dim
@property
def num_hidden_layers(self):
return self.text_config.num_hidden_layers
@property
def hidden_size(self):
return self.text_config.hidden_size
@property
def vocab_size(self):
return self.text_config.vocab_size
@property
def max_position_embeddings(self):
return self.text_config.max_position_embeddings
MolmoAct2VitConfig.register_for_auto_class()
MolmoAct2AdapterConfig.register_for_auto_class()
MolmoAct2TextConfig.register_for_auto_class()
MolmoAct2ActionExpertConfig.register_for_auto_class()
MolmoAct2Config.register_for_auto_class()
@@ -0,0 +1,564 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
if i * j <= max_num_crops:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
# The original size can be zero in rare cases if the image is smaller than the margin
# In those cases letting the scale become infinite means the tiling is based on the
# other side, or falls back to the smallest tiling
with np.errstate(divide="ignore"):
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def build_overlapping_crops(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Decompose an image into a set of overlapping crops
:return crop_arr: [n_crops, h, w, 3] The crops
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
the crops were extracted from, what patch in `crop_arr` it corresponds to
"""
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
assert base_image_input_size[0] == base_image_input_size[1]
left_margin, right_margin = overlap_margins
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * image_patch_size
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Decide how to tile the image, to account for the overlap margins we compute the tiling
# as if we had an image without the margins and were using a crop size without the margins
tiling = select_tiling(
original_image_h - total_margin_pixels,
original_image_w - total_margin_pixels,
crop_window_size,
max_crops,
)
src = resize_image(
image,
[
tiling[0] * crop_window_size + total_margin_pixels,
tiling[1] * crop_window_size + total_margin_pixels,
],
resample,
)
src = normalize_image(src, image_mean, image_std)
# Now we have to split the image into crops, and track what patches came from
# where in `patch_idx_arr`
n_crops = tiling[0] * tiling[1]
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
on_crop = 0
for i in range(tiling[0]):
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
# which results in overlapping crop windows
y0 = i * crop_window_size
for j in range(tiling[1]):
x0 = j * crop_window_size
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
patch_idx += on_crop * crop_patch_h * crop_patch_w
# Mask out idx that are in the overlap region
if i != 0:
patch_idx[:left_margin, :] = -1
if j != 0:
patch_idx[:, :left_margin] = -1
if i != tiling[0] - 1:
patch_idx[-right_margin:, :] = -1
if j != tiling[1] - 1:
patch_idx[:, -right_margin:] = -1
patch_idx_arr[on_crop] = patch_idx
on_crop += 1
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
# so it is ordered left-to-right order
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
# Now get the parts not in the overlap region, so it should map each patch in `src`
# to the correct patch it should come from in `crop_arr`
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
src.shape[0] // image_patch_size,
src.shape[1] // image_patch_size,
)
return crop_arr, patch_idx_arr
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each (low-res, high-res) image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
pooling_w = image_pooling_w
pooling_h = image_pooling_h
crop_patch_w = base_image_input_size[1] // base_image_input_d
crop_patch_h = base_image_input_size[0] // base_image_input_d
if crop_mode == "resize":
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [np.array([resized_h, resized_w, 0, 0])]
return (
np.stack(image_grid, 0),
batch_pixels_to_patches(resized, image_patch_size),
resize_idx,
)
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
crop_arr, patch_idx_arr = build_overlapping_crops(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
# Finally do the same for the global image
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
crop_arr = np.concatenate([resized, crop_arr], 0)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
# Global image goes first, so the order of patches in previous crops gets increased
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
pooling_idx = np.concatenate([resize_idx, pooling_idx])
image_grid = [np.array([resized_h, resized_w, h, w])]
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
max_crops: int | None
overlap_margins: list[int] | None
crop_mode: str | None
patch_size: int | None
pooling_size: list[int] | None
class MolmoAct2ImageProcessor(BaseImageProcessor):
r"""
Constructs a MolmoAct2 image processor that preprocesses images for the model.
Args:
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use when resizing the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `8`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to 14):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
The pooling size of the vision adapter.
"""
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
def __init__(
self,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
self.resample = resample
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.crop_mode = crop_mode
self.patch_size = patch_size
self.pooling_size = pooling_size
def preprocess(
self,
images: ImageInput,
size: dict[str, int] | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
max_crops: int | None = None,
overlap_margins: list[int] | None = None,
crop_mode: str | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Args:
images (`ImageInput`):
Image to preprocess.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `self.max_crops`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values`: The preprocessed images.
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
- `image_grids`: The image grids.
- `image_num_crops`: The number of crops for each image.
"""
if size is not None:
if "height" not in size or "width" not in size:
raise ValueError("size must contain 'height' and 'width' keys.")
else:
size = {**self.size}
base_image_input_size = [size["height"], size["width"]]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
crop_mode = crop_mode or self.crop_mode
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
if images is not None:
images = self.fetch_images(images)
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
data = {}
if images is not None:
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
batch_num_crops = []
for image in images:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
crop_mode,
)
batch_grids.append(image_grid)
batch_crops.append(crops)
batch_pooled_patches_idx.append(pooled_idx)
batch_num_crops.append(crops.shape[0])
pixel_values = np.concatenate(batch_crops, 0)
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
image_grids = np.concatenate(batch_grids, 0)
image_num_crops = np.array(batch_num_crops)
data.update(
pixel_values=pixel_values,
image_token_pooling=image_token_pooling,
image_grids=image_grids,
image_num_crops=image_num_crops,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2ImageProcessor.register_for_auto_class()
@@ -0,0 +1,748 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@dataclass
class _ActionFlowInputs:
trajectory: torch.Tensor
context: Any
modulations: Sequence[Any]
action_dim_is_pad: torch.Tensor | None
@dataclass
class _ActionFlowCudaGraph:
key: tuple[Any, ...]
graph: torch.cuda.CUDAGraph
static_inputs: _ActionFlowInputs
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphLayerStage:
residual: torch.Tensor
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphPostStage:
graph: torch.cuda.CUDAGraph
attn_context: torch.Tensor
@dataclass
class _DepthDecodeCudaGraph:
cache_key: tuple[Any, ...]
pre_graph: torch.cuda.CUDAGraph
token_ids: torch.Tensor
cos: torch.Tensor
sin: torch.Tensor
positions: torch.Tensor
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphSpec:
eligible: bool
cache_key_prefix: tuple[Any, ...]
num_hidden_layers: int
head_dim: int
num_attention_heads: int
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return 0
seq_len = past_key_values.get_seq_length()
if torch.is_tensor(seq_len):
return int(seq_len.item())
return int(seq_len)
def _cache_max_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return -1
max_len = past_key_values.get_max_cache_shape()
if torch.is_tensor(max_len):
return int(max_len.item())
return int(max_len)
def _iter_cache_key_values(
past_key_values: Cache,
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
layers = getattr(past_key_values, "layers", None)
if layers is not None:
for layer in layers:
yield getattr(layer, "keys", None), getattr(layer, "values", None)
return
for layer in past_key_values:
yield layer[0], layer[1]
class _DepthDecodeStaticLayerCache:
is_compileable = False
is_sliding = False
def __init__(self, max_cache_len: int) -> None:
self.max_cache_len = int(max_cache_len)
self.cumulative_length = 0
self.keys: torch.Tensor | None = None
self.values: torch.Tensor | None = None
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
bsz, n_heads = key_states.shape[:2]
self.keys = torch.empty(
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
dtype=key_states.dtype,
device=key_states.device,
)
self.values = torch.empty(
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
dtype=value_states.dtype,
device=value_states.device,
)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
*args,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.keys is None:
self._allocate(key_states, value_states)
start = self.cumulative_length
end = start + key_states.shape[-2]
if end > self.max_cache_len:
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
self.keys[:, :, start:end, :].copy_(key_states)
self.values[:, :, start:end, :].copy_(value_states)
self.cumulative_length = end
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
def get_seq_length(self) -> int:
return self.cumulative_length
def get_max_cache_shape(self) -> int:
return -1
def reset(self) -> None:
self.cumulative_length = 0
class _DepthDecodeStaticCache(Cache):
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
text_config = config.get_text_config(decoder=True)
super().__init__(
layers=[
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
for _ in range(text_config.num_hidden_layers)
]
)
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_max_cache_shape()
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class ActionCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.enabled = True
self.action_flow_graph: _ActionFlowCudaGraph | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
action_model = self.model
if not self.enabled:
return False
if action_model.training or action_model._require_action_expert().training:
return False
if inputs.trajectory.device.type != "cuda":
return False
def all_on_cuda():
yield inputs.trajectory
for k, v in inputs.context.kv_contexts:
yield k
yield v
for t in (
inputs.context.cross_mask,
inputs.context.self_mask,
inputs.context.valid_action,
inputs.action_dim_is_pad,
):
if t is not None:
yield t
if inputs.context.rope_cache is not None:
yield from inputs.context.rope_cache
for step in inputs.modulations:
yield step.conditioning
for block_modulation in step.block_modulations:
yield from block_modulation
yield from step.final_modulation
return all(t.device.type == "cuda" for t in all_on_cuda())
def run_action_flow(
self,
inputs: _ActionFlowInputs,
steps: int,
run_loop,
) -> torch.Tensor:
key = _cuda_graph_key(inputs, steps)
cache = self.action_flow_graph
if cache is None or cache.key != key:
static_inputs = _clone_static_inputs(inputs)
graph, output = _capture_cuda_graph(
lambda: run_loop(static_inputs, steps),
inputs.trajectory.device,
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
)
cache = _ActionFlowCudaGraph(
key=key,
graph=graph,
static_inputs=static_inputs,
output=output,
)
self.action_flow_graph = cache
else:
_copy_inputs_(cache.static_inputs, inputs)
cache.graph.replay()
return cache.output.clone()
class DepthDecodeCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.backbone = model.model
self.enabled = True
self.graph: _DepthDecodeCudaGraph | None = None
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
return _DepthDecodeStaticCache(
config=self.model.config.text_config,
max_cache_len=max_cache_len,
)
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
static = self.graph_spec
if static is None:
cfg = self.backbone.transformer.config
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
static = _DepthDecodeCudaGraphSpec(
eligible=(
not cfg.norm_after
and cfg.rope_scaling_layers is None
and getattr(rotary_emb, "rope_type", None) == "default"
and cfg._attn_implementation == "sdpa"
),
cache_key_prefix=(
cfg.hidden_size,
cfg.num_attention_heads,
cfg.num_key_value_heads,
cfg.head_dim,
cfg.num_hidden_layers,
cfg.use_qk_norm,
cfg.qk_norm_type,
cfg._attn_implementation,
),
num_hidden_layers=cfg.num_hidden_layers,
head_dim=cfg.head_dim,
num_attention_heads=cfg.num_attention_heads,
)
self.graph_spec = static
return static
def can_use(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
) -> bool:
if not self.enabled or self.model.training or self.backbone.transformer.training:
return False
if next_input_ids.device.type != "cuda":
return False
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
return False
if not isinstance(past_key_values, _DepthDecodeStaticCache):
return False
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
return False
return self._depth_decode_spec().eligible
def _depth_decode_key(
self,
next_input_ids: torch.Tensor,
attention_bias: torch.Tensor,
) -> tuple[Any, ...]:
device = next_input_ids.device
return (
self._depth_decode_spec().cache_key_prefix,
device.type,
device.index,
self.model.lm_head.weight.dtype,
attention_bias.shape[-1],
)
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
emb = self.backbone.transformer.rotary_emb
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
def _depth_decode_pre_layer(
self,
layer_idx: int,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
residual = hidden_states
hidden_states = block.attn_norm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attention.head_dim)
qkv = attention.att_proj(hidden_states)
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
value_states = value_states.view(hidden_shape)
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
if apply_qk_norm and not norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.view(hidden_shape)
key_states = key_states.view(hidden_shape)
if norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
return residual, query_states, key_states, value_states
def _depth_decode_pre0(
self,
token_ids: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
inputs_embeds = self.model._embed_base_tokens(token_ids)
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
def _depth_decode_post_layer(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
input_shape = residual.shape[:-1]
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
attn_output = attention.attn_out(attn_output)
hidden_states = residual + block.dropout(attn_output)
residual = hidden_states
hidden_states = block.ff_norm(hidden_states)
hidden_states = block.mlp(hidden_states)
hidden_states = residual + block.dropout(hidden_states)
return hidden_states
def _depth_decode_post_and_pre_next(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
def _depth_decode_last_post(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self.backbone.transformer.ln_f(hidden_states)
def _build_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
text_config = self.backbone.transformer.config
device = next_input_ids.device
dtype = self.model.lm_head.weight.dtype
static = self._depth_decode_spec()
num_layers = static.num_hidden_layers
head_dim = static.head_dim
max_cache_len = int(attention_bias.shape[-1])
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
sin = torch.empty_like(cos)
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
context_shape = (1, 1, static.num_attention_heads, head_dim)
token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(cos, sin, past_length=past_length)
pre_graph, pre_output = _capture_cuda_graph(
lambda: self._depth_decode_pre0(token_ids, cos, sin),
device,
)
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
post_graphs = []
for layer_idx in range(num_layers - 1):
stage = stages[-1]
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
graph, output = _capture_cuda_graph(
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
self._depth_decode_post_and_pre_next(
layer_idx,
stage.residual,
attn_context,
cos,
sin,
)
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
last_stage = stages[-1]
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
last_graph, last_output = _capture_cuda_graph(
lambda: self._depth_decode_last_post(
num_layers - 1,
last_stage.residual,
last_attn_context,
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
return _DepthDecodeCudaGraph(
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
pre_graph=pre_graph,
token_ids=token_ids,
cos=cos,
sin=sin,
positions=positions,
stages=tuple(stages),
post_graphs=tuple(post_graphs),
output=last_output,
)
def _get_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
key = self._depth_decode_key(next_input_ids, attention_bias)
decode_graph = self.graph
if decode_graph is None or decode_graph.cache_key != key:
decode_graph = self._build_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
self.graph = decode_graph
else:
decode_graph.token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
return decode_graph
def _run_depth_decode_attention_core(
self,
layer_idx: int,
stage: _DepthDecodeCudaGraphLayerStage,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
cache_position: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
attention = self.backbone.transformer.blocks[layer_idx].self_attn
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
stage.key,
stage.value,
layer_idx,
cache_kwargs,
)
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
attn_output = F.scaled_dot_product_attention(
stage.query,
key_states,
value_states,
attn_mask=attention_bias,
dropout_p=0.0,
is_causal=False,
)
return attn_output.transpose(1, 2)
def run(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
past_length: int,
) -> tuple[torch.Tensor, Cache]:
end = past_length + 1
decode_graph = self._get_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
cache_position = decode_graph.positions[past_length:end]
attention_bias_q = attention_bias[:, :, past_length:end, :end]
decode_graph.pre_graph.replay()
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
attn_context = self._run_depth_decode_attention_core(
layer_idx,
decode_graph.stages[layer_idx],
past_key_values=past_key_values,
attention_bias=attention_bias_q,
cache_position=cache_position,
cos=decode_graph.cos,
sin=decode_graph.sin,
)
post_graph.attn_context.copy_(attn_context)
post_graph.graph.replay()
return decode_graph.output, past_key_values
def _cuda_graph_tensor_signature(
tensor: torch.Tensor | None,
) -> tuple[Any, ...] | None:
if tensor is None:
return None
return (
tuple(tensor.shape),
tuple(tensor.stride()),
str(tensor.dtype),
str(tensor.device),
)
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
sig(context.cross_mask),
sig(context.self_mask),
sig(context.valid_action),
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
)
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return tuple(
(
sig(step.conditioning),
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
tuple(sig(t) for t in step.final_modulation),
)
for step in modulations
)
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
sig(inputs.trajectory),
_cuda_graph_context_signature(inputs.context),
_cuda_graph_modulation_signature(inputs.modulations),
sig(inputs.action_dim_is_pad),
int(steps),
)
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None:
return None
static = torch.empty_strided(
tuple(tensor.shape),
tuple(tensor.stride()),
device=tensor.device,
dtype=tensor.dtype,
)
static.copy_(tensor)
return static
def _clone_static_context(context: Any) -> Any:
rope_cache = None
if context.rope_cache is not None:
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
return context.__class__(
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
cross_mask=_clone_static_tensor(context.cross_mask),
self_mask=_clone_static_tensor(context.self_mask),
valid_action=_clone_static_tensor(context.valid_action),
rope_cache=rope_cache,
)
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
return tuple(
step.__class__(
conditioning=_clone_static_tensor(step.conditioning),
block_modulations=tuple(
tuple(_clone_static_tensor(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
)
for step in modulations
)
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
return _ActionFlowInputs(
trajectory=_clone_static_tensor(inputs.trajectory),
context=_clone_static_context(inputs.context),
modulations=_clone_static_modulations(inputs.modulations),
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
)
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
dst.cross_mask.copy_(src.cross_mask)
if src.self_mask is not None:
dst.self_mask.copy_(src.self_mask)
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
dst_tensor.copy_(src_tensor)
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
dst.trajectory.copy_(src.trajectory)
_copy_context_(dst.context, src.context)
if src.action_dim_is_pad is not None:
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _capture_cuda_graph(
fn,
device: torch.device,
*,
after_warmup=None,
) -> tuple[torch.cuda.CUDAGraph, Any]:
warmup_stream = torch.cuda.Stream(device=device)
warmup_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(warmup_stream):
fn()
torch.cuda.current_stream(device).wait_stream(warmup_stream)
if after_warmup is not None:
after_warmup()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
output = fn()
return graph, output
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,431 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import logging
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
IMAGE_TOKENS = [
IMAGE_PATCH_TOKEN,
IM_COL_TOKEN,
IM_START_TOKEN,
LOW_RES_IMAGE_START_TOKEN,
FRAME_START_TOKEN,
IM_END_TOKEN,
FRAME_END_TOKEN,
IMAGE_LOW_RES_TOKEN,
]
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
"""MolmoAct2 processor kwargs"""
images_kwargs: MolmoAct2ImagesKwargs
videos_kwargs: MolmoAct2VideoProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"videos_kwargs": {"return_metadata": True},
}
class MolmoAct2Processor(ProcessorMixin):
attributes = ["image_processor", "video_processor", "tokenizer"]
optional_attributes = [
"chat_template",
"time_mode",
"image_use_col_tokens",
"use_single_crop_col_tokens",
"use_single_crop_start_token",
"video_use_col_tokens",
"use_frame_special_tokens",
]
image_processor_class = "AutoImageProcessor"
video_processor_class = "AutoVideoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor: MolmoAct2ImageProcessor = None,
video_processor: MolmoAct2VideoProcessor = None,
tokenizer: AutoTokenizer = None,
chat_template: str | None = None,
image_use_col_tokens: bool | None = True,
use_single_crop_col_tokens: bool | None = None,
use_single_crop_start_token: bool | None = True,
video_use_col_tokens: bool | None = False,
use_frame_special_tokens: bool | None = True,
**kwargs,
) -> None:
super().__init__(
image_processor,
video_processor,
tokenizer,
chat_template=chat_template,
)
self.image_use_col_tokens = image_use_col_tokens
self.use_single_crop_col_tokens = use_single_crop_col_tokens
self.use_single_crop_start_token = use_single_crop_start_token
self.video_use_col_tokens = video_use_col_tokens
self.use_frame_special_tokens = use_frame_special_tokens
self.image_placeholder_token = IMAGE_PROMPT
self.video_placeholder_token = VIDEO_PROMPT
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
def get_image_tokens(self, image_grid: np.ndarray):
resized_h, resized_w, height, width = image_grid
if int(height) == 0 or int(width) == 0:
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
]
return np.concatenate(joint)
per_row = np.full(width, IMAGE_PATCH_TOKEN)
if self.image_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [height]),
[IM_END_TOKEN],
]
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[image_start_token],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
] + joint
return np.concatenate(joint)
def get_video_string(
self,
video_grid: np.ndarray,
timestamps: np.ndarray,
):
if self.use_frame_special_tokens:
start_token_id = FRAME_START_TOKEN
end_token_id = FRAME_END_TOKEN
else:
start_token_id = IM_START_TOKEN
end_token_id = IM_END_TOKEN
num_frames, h, w = video_grid
video_string: str = ""
for frame_idx, frame_time in enumerate(timestamps):
# `per-frame-compact` time mode
prev_space = " " if frame_idx > 0 else ""
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
video_string += frame_prefix
per_row = np.full(w, IMAGE_PATCH_TOKEN)
if self.video_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
extra_tokens = np.tile(per_row, [h])
video_tokens = [
[start_token_id],
extra_tokens,
[end_token_id],
]
video_string += "".join(np.concatenate(video_tokens, 0))
return video_string
def insert_bos(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
bos_token_id: int,
pad_token_id: int,
):
"""
Args:
input_ids: [B, S] array with left padding
attention_mask: [B, S] array (0 for pad, 1 for valid)
bos_token_id: int
pad_token_id: int
Returns:
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
attention_mask_out: same shape as input_ids_out
"""
need_to_expand = len(input_ids.shape) == 1
if need_to_expand:
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
# Handle zero-length sequence
if S == 0:
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
if bos_already_present:
if need_to_expand:
input_ids = input_ids[0]
attention_mask = attention_mask[0]
return input_ids, attention_mask
else:
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
tgt_idx = src_idx + 1 # shit right
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
# flatten valid_positions
flat_vals = input_ids[valid_mask]
flat_batch = batch_idx[valid_mask]
flat_tgt = tgt_idx[valid_mask]
new_input_ids[flat_batch, flat_tgt] = flat_vals
new_attention_mask[flat_batch, flat_tgt] = 1
insert_pos = first_valid_index
new_input_ids[np.arange(B), insert_pos] = bos_token_id
new_attention_mask[np.arange(B), insert_pos] = 1
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
def __call__(
self,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
) -> BatchFeature:
"""
Args:
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
- `"timestamps"`: `np.ndarray` of shape (T,)
- `"sampled_fps"`: `float` (optional)
- `"sampling_augmentation"`: `str` (optional)
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
`BatchFeature`: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
Returned when `images` is not `None`.
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
Returned when `videos` is not `None`.
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
MolmoAct2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
image_grids = image_inputs["image_grids"]
else:
image_inputs = {}
image_grids = None
if videos is not None:
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grids = videos_inputs["video_grids"]
# If user has not requested video metadata, pop it
if "return_metadata" not in kwargs:
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
else:
videos_inputs = {}
video_grids = None
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
if image_grids is not None:
index = 0
for i in range(len(text)):
num_images = text[i].count(self.image_placeholder_token)
image_grids_i = image_grids[index : index + num_images]
for image_grid in image_grids_i:
image_tokens = self.get_image_tokens(image_grid)
image_string = "".join(image_tokens)
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
index += num_images
if video_grids is not None:
index = 0
for i in range(len(text)):
num_videos = text[i].count(self.video_placeholder_token)
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
)
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
index += num_videos
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
input_ids = np.array(input_ids)
attention_mask = np.array(attention_mask)
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
input_ids, attention_mask = self.insert_bos(
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
)
if return_mm_token_type_ids:
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
text_inputs["token_type_ids"] = token_type_ids.tolist()
text_inputs["input_ids"] = input_ids.tolist()
text_inputs["attention_mask"] = attention_mask.tolist()
return BatchFeature(
data={**text_inputs, **image_inputs, **videos_inputs},
tensor_type=return_tensors,
)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns:
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
MolmoAct2Processor.register_for_auto_class()
@@ -0,0 +1,997 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from contextlib import redirect_stdout
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
logger = logging.get_logger(__name__)
MAX_VIDEO_FPS = 8
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
if len(image.shape) == 3:
is_video = False
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
else:
is_video = True
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
if is_video:
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
else:
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: ImageInput,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
pooling_w = image_pooling_w
pooling_h = image_pooling_h
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [h, w]
return (
image_grid,
batch_pixels_to_patches(resized, image_patch_size),
pooling_idx,
)
def get_candidate_target_fps(
video_fps: int | float,
sampling_fps: int | float,
max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
Examples:
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
[2, 6]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
[1, 5]
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
[2]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
Traceback (most recent call last):
...
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
"""
video_fps = int(video_fps)
sampling_fps = int(sampling_fps)
max_fps = int(max_fps)
if sampling_fps is None:
raise ValueError("sampling_fps must be provided")
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
if video_fps % sampling_fps != 0:
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
def read_video_decord(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the Decord backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import from decord
import importlib
decord = importlib.import_module("decord")
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
duration = time_stamps[-1][1] - time_stamps[0][0]
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="decord",
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
target_timestamps = np.array(target_timestamps)
offset = time_stamps[0, 0]
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
ix = np.minimum(ix, len(time_stamps) - 1)
video = vr.get_batch(ix).asnumpy()
metadata.update(
{
"frames_indices": target_timestamps * video_fps,
"height": video.shape[1],
"width": video.shape[2],
}
)
return video, metadata
def read_video_torchcodec(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using torchcodec decoder.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
torchcodec = importlib.import_module("torchcodec")
decoder = torchcodec.decoders.VideoDecoder(
video_path,
# Interestingly `exact` mode takes less than approximate when we load the whole video
seek_mode="exact",
# Allow FFmpeg decide on the number of threads for efficiency
num_ffmpeg_threads=0,
)
# If the first frame starts at > 0, we effectively clip the video starting at that time
# since (most) video players would also skip to that time
time_offset = decoder.metadata.begin_stream_seconds_from_content
# Note this duration does assume we started playing at `time_offset`
duration = decoder.metadata.duration_seconds
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
duration=duration,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
# out-of-bounds, to handle this we sanity check then clip them
assert all(x >= 0 for x in target_timestamps)
assert all(x < duration + 1e-6 for x in target_timestamps)
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
# exact boundary value, we should still get the first/last frame anyway
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
# Note we avoid using numpy ops here to reduce floating precision issues
timestamps = [x + time_offset for x in target_timestamps]
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
video = (
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
) # Convert to THWC format
target_timestamps = np.array(target_timestamps)
metadata.frames_indices = target_timestamps * metadata.fps
return video, metadata
def read_video_pyav(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the PyAV backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
av = importlib.import_module("av")
with av.open(video_path) as container:
video_stream = container.streams.video[0]
fps = video_stream.average_rate or video_stream.guessed_rate
it = container.decode(video=0)
frames = list(it)
stream = container.streams.video[0]
start = frames[0].pts * stream.time_base
container_end = stream.duration
if container_end is not None:
container_end *= stream.time_base
if container_end is None or container_end < frames[-1].pts:
# Some problem with stream duration, so use the frame PTS directly
# and guess the duration of the last frame
end = frames[-1].pts * stream.time_base + 1 / fps
else:
end = container_end
duration = float(end - start)
metadata = VideoMetadata(
total_num_frames=len(frames),
fps=float(fps),
duration=float(duration),
video_backend="pyav",
height=video_stream.height,
width=video_stream.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
offset = float(start)
target_timestamps = np.array(target_timestamps)
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
indices = np.minimum(indices, len(end_time_stamps) - 1)
video = np.stack(
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
axis=0,
)
metadata.frames_indices = target_timestamps * fps
return video, metadata
VIDEO_DECODERS = {
"decord": read_video_decord,
"torchcodec": read_video_torchcodec,
"pyav": read_video_pyav,
}
def load_video(
video: VideoInput,
backend: str = "decord",
sample_timestamps_fn: Callable | None = None,
**kwargs,
):
"""
Loads `video` to a numpy array.
Args:
video (`VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
backend (`str`, *optional*, defaults to `"decord"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
"""
# Early exit if provided an array or `PIL` frames
if not isinstance(video, str):
metadata = [None] * len(video)
return video, metadata
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
import importlib
yt_dlp = importlib.import_module("yt_dlp")
buffer = BytesIO()
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
f.download([video])
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video, timeout=10).content)
elif os.path.isfile(video):
file_obj = video
else:
raise TypeError(
"Incorrect format used for video. Should be an url linking to an video or a local path."
)
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend == "opencv":
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
if (
(not is_decord_available() and backend == "decord")
or (not is_torchcodec_available() and backend == "torchcodec")
or (not is_av_available() and backend == "pyav")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
f"Make sure to install {backend} before loading the video."
)
video_decoder = VIDEO_DECODERS[backend]
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
return video, metadata
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: tuple[float],
) -> float:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames, choose the one with higher density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
if selected_target_fps is None:
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
patch_size: int | None
pooling_size: list[int] | None
frame_sample_mode: str | None
max_fps: int | None
sampling_fps: int | None
class MolmoAct2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
size = {"height": 378, "width": 378}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
patch_size = 14
pooling_size = [3, 3]
do_sample_frames = True
frame_sample_mode = "uniform_last_frame"
max_fps = 2
sampling_fps = 2
valid_kwargs = MolmoAct2VideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
super().__init__(**kwargs)
if self.size is not None and (
self.size.get("height", None) is None or self.size.get("width", None) is None
):
raise ValueError("size must contain 'height' and 'width' keys.")
def _further_process_kwargs(
self,
size: SizeDict | None = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if size is not None and ("height" not in size or "width" not in size):
raise ValueError("size must contain 'height' and 'width' keys.")
return super()._further_process_kwargs(size=size, **kwargs)
def sample_times(
self,
metadata: VideoMetadata,
frame_sample_mode: str,
num_frames: int,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Time-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
man_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
duration = metadata.duration or metadata.total_num_frames / metadata.fps
if frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
# Try larger and larger FPSs until we hit one that can't span the video
target_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if num_frames / candidate_fps < duration:
break
target_fps = candidate_fps
times = np.arange(0, num_frames) / target_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= num_frames
else:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
return times
else:
raise NotImplementedError(frame_sample_mode)
def sample_frames(
self,
metadata: VideoMetadata,
frame_sample_mode: str | None = None,
num_frames: int | None = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Frame-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
max_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
total_num_frames = metadata.total_num_frames
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
duration = total_num_frames / metadata.fps
if total_num_frames <= 2:
return np.arange(total_num_frames).astype(int)
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(metadata.fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
return indices
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
elif frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
selected_target_fps = get_target_fps(
metadata.fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
metadata.fps,
)
return indices
else:
raise NotImplementedError(frame_sample_mode)
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
"""
Convert a single or a list of urls into the corresponding `np.array` objects.
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
returned.
"""
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
raise ImportError(
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
)
if is_decord_available():
backend = "decord"
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
)
backend = "pyav"
if isinstance(video_url_or_urls, list):
return list(
zip(
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
)
)
else:
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
def _decode_and_sample_videos(
self,
videos: VideoInput,
video_metadata: VideoMetadata | dict,
do_sample_frames: bool | None = None,
sample_indices_fn: Callable | None = None,
sample_timestamps_fn: Callable | None = None,
):
"""
Decode input videos and sample frames if needed.
"""
videos = make_batched_videos(videos)
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
# Framed-based sampling if an array video is passed
# Otherwise, time-based sampling with decoding
if is_valid_video(videos[0]) and do_sample_frames:
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
sampled_metadata.append(metadata)
videos = sampled_videos
video_metadata = sampled_metadata
elif not is_valid_video(videos[0]):
if sample_indices_fn is None:
logger.warning(
"do_sample_frames is False, but video array is not provided: "
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
)
if isinstance(videos[0], list):
raise ValueError("A list of images is not supported for video input!")
else:
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
return videos, video_metadata
def _prepare_input_videos(
self,
videos: VideoInput,
**kwargs,
) -> list[np.ndarray]:
processed_videos = [to_numpy(video) for video in videos]
return processed_videos
def preprocess(
self,
videos: VideoInput,
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
do_sample_frames = kwargs.pop("do_sample_frames")
video_metadata = kwargs.pop("video_metadata")
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
sample_timestamps_fn = partial(self.sample_times, **kwargs)
videos, video_metadata = self._decode_and_sample_videos(
videos,
video_metadata=video_metadata,
do_sample_frames=do_sample_frames,
sample_indices_fn=sample_indices_fn,
sample_timestamps_fn=sample_timestamps_fn,
)
videos = self._prepare_input_videos(videos=videos)
kwargs = self._further_process_kwargs(**kwargs)
return_metadata = kwargs.pop("return_metadata")
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
if return_metadata:
preprocessed_videos["video_metadata"] = video_metadata
return preprocessed_videos
def _preprocess(
self,
videos: list[np.ndarray],
size: SizeDict | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess a video for the model.
Args:
videos (`VideoInput`):
Video to preprocess.
size (`SizeDict`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values_videos`: The preprocessed videos.
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
- `video_grids`: The video grids.
"""
if size.height is None or size.width is None:
raise ValueError("size must contain 'height' and 'width' keys.")
base_image_input_size = [size.height, size.width]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
for video in videos:
all_crops = []
pooled_patches_idx = []
for frame in video:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
frame,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
)
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
pooled_patches_idx.append(pooled_idx_with_offset)
all_crops.append(crops)
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
all_crops = np.concatenate(all_crops, 0)
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
batch_grids.append(video_grid)
batch_crops.append(all_crops)
batch_pooled_patches_idx.append(pooled_patches_idx)
video_grids = np.stack(batch_grids, 0)
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2VideoProcessor.register_for_auto_class()
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+85 -5
View File
@@ -29,6 +29,7 @@ from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.__version__ import __version__
from lerobot.configs import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.hub import HubMixin
@@ -38,6 +39,67 @@ from .utils import log_model_loading_keys
T = TypeVar("T", bound="PreTrainedPolicy")
def _build_card_context(
cfg: TrainPipelineConfig | None,
dataset_repo_id: str | None,
input_features: dict | None,
output_features: dict | None,
) -> dict:
"""Collect optional data for the model-card template.
Returns plain values only (no Markdown) the template in
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
each one. Everything is best-effort: anything unavailable is left empty/None and the
template simply skips that section, so this never breaks a Hub push.
"""
context = {
"training": None,
"input_features": input_features or {},
"output_features": output_features or {},
"dataset": None,
"robot_type": None,
"cameras": [],
}
if cfg is not None:
optimizer = getattr(cfg, "optimizer", None)
context["training"] = {
"steps": cfg.steps,
"batch_size": cfg.batch_size,
"seed": cfg.seed,
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
"lr": getattr(optimizer, "lr", None) if optimizer else None,
"lerobot_version": __version__,
}
if dataset_repo_id:
dataset_cfg = getattr(cfg, "dataset", None)
try:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(
dataset_repo_id,
root=getattr(dataset_cfg, "root", None),
revision=getattr(dataset_cfg, "revision", None),
)
context["dataset"] = {
"repo_id": dataset_repo_id,
"episodes": meta.total_episodes,
"frames": meta.total_frames,
"fps": meta.fps,
"tasks": [str(task) for task in meta.tasks.index],
}
context["robot_type"] = meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
logging.warning(
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
f"omitted from the model card. ({e})"
)
return context
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
@@ -228,7 +290,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
)
card.save(str(saved_path / "README.md"))
@@ -246,9 +308,20 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
self,
dataset_repo_id: str,
model_type: str,
license: str | None,
tags: list[str] | None,
cfg: TrainPipelineConfig | None = None,
) -> ModelCard:
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
"pi0": "lerobot/pi0_base",
"pi05": "lerobot/pi05_base",
"pi0_fast": "lerobot/pi0fast-base",
"xvla": "lerobot/xvla-base",
}
card_data = ModelCardData(
license=license or "apache-2.0",
@@ -257,13 +330,20 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
base_model=base_model,
base_model=base_model_mapping.get(model_type),
)
context = _build_card_context(
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
)
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
context["base_model"] = base_model_mapping.get(model_type)
template_card = (
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
)
card = ModelCard.from_template(card_data, template_str=template_card)
card = ModelCard.from_template(card_data, template_str=template_card, **context)
card.validate()
return card
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_vla_jepa_README.md
+23
View File
@@ -0,0 +1,23 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]
@@ -0,0 +1,337 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions
@@ -0,0 +1,154 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,629 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model
@@ -0,0 +1,155 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold -1, <= threshold 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,117 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
@@ -0,0 +1,418 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)
+279 -55
View File
@@ -32,7 +32,6 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
This method does the actual saving work and is called by HubMixin.save_pretrained.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
"""
config_filename = kwargs.pop("config_filename", None)
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
config: dict[str, Any] = {
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_entry["config"] = processor_step.get_config()
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
pipeline_config["steps"].append(step_entry)
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
return pipeline_config
config["steps"].append(step_entry)
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
return cls(
pipeline = cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config dictionary (guaranteed non-None)
loaded_config: The loaded config value to validate (may be non-dict)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
)
@classmethod
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
# 3. Load step state if available
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
return steps, remaining_override_keys
steps.append(step_instance)
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
return steps, override_keys
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
processor_steps.append(processor_step)
return processor_steps, remaining_override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: dict) -> bool:
def _is_processor_config(cls, config: Any) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions
@ProcessorStepRegistry.register("delta_actions_processor")
@ProcessorStepRegistry.register("relative_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
+2
View File
@@ -20,12 +20,14 @@ from .factory import (
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
"TOPRewardConfig",
# Base class
+16 -2
View File
@@ -25,6 +25,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig
@@ -38,7 +39,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm", "topreward".
"sarm", "robometer", "topreward".
Returns:
The reward model class corresponding to the given name.
@@ -54,6 +55,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "robometer":
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
return RobometerRewardModel
elif name == "topreward":
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
@@ -74,7 +79,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm", "topreward".
"reward_classifier", "sarm", "robometer", "topreward".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -87,6 +92,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
elif reward_type == "robometer":
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
else:
@@ -168,6 +175,13 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(reward_cfg, RobometerConfig):
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
return make_robometer_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, TOPRewardConfig):
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_robometer import RobometerConfig
from .modeling_robometer import RobometerRewardModel
from .processor_robometer import make_robometer_pre_post_processors
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
@@ -0,0 +1,320 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
For each episode, builds per-frame sub-samples using the frame-steps
strategy from the Robometer eval server: for each original frame ``t``,
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
the Robometer processor + model, and keep the last-frame progress value.
All sub-samples are the same size ``K`` so they batch cleanly.
The parquet uses the same schema as SARM's
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
``progress_sparse``) and the progress-overlay script in
``examples/dataset/create_progress_videos.py`` work without modification.
Usage:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes with batching
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
from lerobot.types import TransitionKey
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
"""Read ``reward_model_path`` from parquet metadata if available."""
if not parquet_path.exists():
return None
try:
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
if metadata and b"reward_model_path" in metadata:
return metadata[b"reward_model_path"].decode()
except Exception: # nosec B110
return None
return None
def _resolve_task(sample: dict[str, Any], default: str) -> str:
"""Best-effort task extraction from a dataset sample."""
task = sample.get("task")
if isinstance(task, str) and task:
return task
return default
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
"""Frame-steps linspace expansion.
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
indices from ``np.linspace(0, t, num_subsampled_frames)`` the first
and last frames are always included. Each entry is a fixed-size array
so the model can batch them.
"""
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
def compute_robometer_progress(
dataset_repo_id: str,
reward_model_path: str,
output_path: str | None = None,
device: str = "cuda",
batch_size: int = 32,
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
episodes: list[int] | None = None,
image_key: str | None = None,
) -> Path:
"""Run Robometer over a dataset and write per-frame progress + success."""
logging.info(f"Loading Robometer: {reward_model_path}")
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
if image_key is not None:
config.image_key = image_key
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
model.to(device).eval()
encoder = RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=num_subsampled_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
)
image_key = config.image_key
logging.info(f"Loading dataset: {dataset_repo_id}")
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
logging.info(f"Processing {len(episode_indices)} episode(s)")
all_index: list[int] = []
all_episode: list[int] = []
all_frame: list[int] = []
all_progress: list[float] = []
for episode_idx in tqdm(episode_indices, desc="Episodes"):
ep = dataset.meta.episodes[episode_idx]
ep_start = int(ep["dataset_from_index"])
ep_end = int(ep["dataset_to_index"])
num_frames = ep_end - ep_start
if num_frames <= 0:
continue
first_sample = dataset[ep_start]
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
end = min(start + batch_size, num_frames)
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
transition = {
TransitionKey.OBSERVATION: {image_key: frames_batch},
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
}
encoded = encoder(transition)
obs = encoded[TransitionKey.OBSERVATION]
batch = {
key: value.to(device) if isinstance(value, torch.Tensor) else value
for key, value in obs.items()
}
with torch.no_grad():
rewards = model.compute_reward(batch)
progress_per_frame[start:end] = rewards.cpu().numpy()
for local in range(num_frames):
all_index.append(ep_start + local)
all_episode.append(episode_idx)
all_frame.append(local)
all_progress.append(float(progress_per_frame[local]))
if device.startswith("cuda"):
torch.cuda.empty_cache()
table = pa.table(
{
"index": np.asarray(all_index, dtype=np.int64),
"episode_index": np.asarray(all_episode, dtype=np.int64),
"frame_index": np.asarray(all_frame, dtype=np.int64),
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
}
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out)
logging.info(f"Saved {len(table)} frame values to {out}")
progress_arr = np.asarray(all_progress, dtype=np.float32)
if progress_arr.size:
logging.info(
f"Progress: mean={float(progress_arr.mean()):.4f}, "
f"std={float(progress_arr.std()):.4f}, "
f"min={float(progress_arr.min()):.4f}, "
f"max={float(progress_arr.max()):.4f}"
)
return out
def main():
parser = argparse.ArgumentParser(
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes, smaller batches for memory-constrained GPUs
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
""",
)
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
)
parser.add_argument(
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
)
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
parser.add_argument(
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
)
parser.add_argument(
"--num-subsampled-frames",
type=int,
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
)
parser.add_argument(
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
)
parser.add_argument(
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
)
parser.add_argument(
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
reward_model_path = args.reward_model_path
if reward_model_path is None:
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
if reward_model_path:
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
else:
raise ValueError(
"--reward-model-path is required (no existing parquet with model metadata found)."
)
output_path = compute_robometer_progress(
dataset_repo_id=args.dataset_repo_id,
reward_model_path=reward_model_path,
output_path=args.output_path,
device=args.device,
batch_size=args.batch_size,
num_subsampled_frames=args.num_subsampled_frames,
episodes=args.episodes,
image_key=args.image_key,
)
print(f"\nRobometer progress saved to: {output_path}")
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = DEFAULT_OUTPUT_FILENAME
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(output_path),
path_in_repo=hub_path,
repo_id=args.dataset_repo_id,
repo_type="dataset",
)
print(
"Successfully uploaded to: "
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
)
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
print(" rabc_head_mode: sparse")
else:
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: {output_path}")
print(" rabc_head_mode: sparse")
if __name__ == "__main__":
main()
@@ -0,0 +1,158 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoTokenizer
else:
AutoConfig = None # type: ignore[assignment]
AutoTokenizer = None # type: ignore[assignment]
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
# The order is part of the data contract: upstream resized ``embed_tokens``
# after adding these tokens in this exact order, so changing the set or order
# would silently misalign the saved embedding rows with their token ids.
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
# heads (never read at inference) but still occupy rows the checkpoint expects.
ROBOMETER_SPECIAL_TOKENS = (
"<|split_token|>",
"<|reward_token|>",
"<|pref_token|>",
"<|sim_token|>",
"<|prog_token|>",
)
@RewardModelConfig.register_subclass("robometer")
@dataclass
class RobometerConfig(RewardModelConfig):
"""Configuration for the Robometer reward model."""
pretrained_path: str | None = "lerobot/Robometer-4B"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
reward_output: str = "progress" # "progress" or "success"
success_threshold: float = 0.5
license: str | None = "apache-2.0"
tags: list[str] | None = field(
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
)
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
torch_dtype: str = "bfloat16"
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
average_temporal_patches: bool = True
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
frame_pooling_attn_temperature: float = 1.0
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
progress_discrete_bins: int = 10
# Serialised Qwen backbone config (post-resize). Always populated by
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
# is non-empty after construction. Saved into ``config.json`` automatically
# by the base ``_save_pretrained``.
vlm_config: dict[str, Any] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self) -> None:
super().__post_init__()
if self.reward_output not in {"progress", "success"}:
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
if self.max_frames is not None and self.max_frames < 1:
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
if self.frame_pooling not in {"mean", "boundary", "attention"}:
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
if self.frame_pooling_attn_temperature <= 0:
raise ValueError("frame_pooling_attn_temperature must be > 0")
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
if self.use_per_frame_progress_token and not self.use_multi_image:
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
if self.image_key not in self.input_features:
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
# Deterministically populate ``vlm_config`` so it is non-empty after
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
# vocab the published ``Robometer-4B`` checkpoint was saved with.
if not self.vlm_config:
require_package("transformers", extra="robometer")
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
text_config = vlm.get("text_config")
if not isinstance(text_config, dict):
raise ValueError(
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
"Robometer expects a Qwen-VL-style config."
)
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
self.vlm_config = vlm
@property
def use_discrete_progress(self) -> bool:
"""Whether the progress head outputs distribution logits over bins."""
return self.progress_loss_type.lower() == "discrete"
@property
def vlm_backbone_config(self):
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
require_package("transformers", extra="robometer")
config_dict = deepcopy(self.vlm_config)
model_type = config_dict.pop("model_type", None)
if model_type is None:
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
return AutoConfig.for_model(model_type, **config_dict)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> None:
return None
@property
def reward_delta_indices(self) -> None:
return None
def validate_features(self) -> None:
if self.image_key not in self.input_features:
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
@@ -0,0 +1,481 @@
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
Paper: https://arxiv.org/abs/2603.02115
Project: https://robometer.github.io
Original code: https://github.com/aliang8/robometer
Model: https://huggingface.co/robometer/Robometer-4B
Robometer is a general-purpose, video-language-input reward model built on
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
objective:
- A frame-level progress loss anchoring reward magnitude on expert data.
- A trajectory-comparison preference loss imposing global ordering constraints
across trajectories sharing the same instruction.
To support downstream RL it also predicts a frame-level binary success. The
training prompt inserts three learnable tokens:
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
- ``<|split_token|>`` between two trajectories in preference samples
(training-only).
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
is recovered as the softmax-weighted mean of those centers see
:func:`convert_bins_to_continuous`.
This LeRobot port is **inference-only**: the preference head is preserved in
the state dict for byte-equivalence with the published ``Robometer-4B``
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
success probability depending on :attr:`RobometerConfig.reward_output`.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import torch
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModelForImageTextToText
else:
AutoModelForImageTextToText = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
ROBOMETER_QWEN_INPUT_KEYS = (
"input_ids",
"attention_mask",
"pixel_values",
"pixel_values_videos",
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"mm_token_type_ids",
)
ROBOMETER_METADATA_KEYS = (
"prog_token_id",
"vision_start_token_id",
"vision_end_token_id",
"video_merge_size",
)
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
"""Collapse per-bin logits into a single value in ``[0, 1]``.
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
softmax-weighted mean of those centers.
"""
bin_probs = torch.softmax(bin_logits, dim=-1)
num_bins = bin_logits.shape[-1]
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
return (bin_probs * bin_centers).sum(dim=-1)
def _squeeze_last_safe(x: Tensor) -> Tensor:
"""Drop a trailing singleton dim only when present."""
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
def _torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if isinstance(dtype, torch.dtype):
return dtype
raise ValueError(f"Unknown torch dtype: {name!r}")
class RobometerPredictionHead(nn.Sequential):
"""Small MLP head used for Robometer's progress / success / preference outputs."""
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
layers: list[nn.Module] = [
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_size),
]
if with_sigmoid:
layers.append(nn.Sigmoid())
super().__init__(*layers)
def decode_progress_outputs(
progress_logits: Tensor | None,
success_logits: Tensor | None,
*,
is_discrete_mode: bool,
) -> dict[str, list[list[float]]]:
"""Decode RBM head outputs into per-frame floats.
Args:
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
is_discrete_mode: if True the progress logits get a softmax over bins
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
Returns:
Dict with ``progress_pred`` and ``success_probs``, each a list of
length ``B`` of per-frame float lists.
"""
progress_pred: list[list[float]] = []
success_probs: list[list[float]] = []
if progress_logits is not None:
for sample_logits in progress_logits:
if is_discrete_mode:
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
progress_pred.append(continuous.flatten().tolist())
else:
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
if success_logits is not None:
for sample_logits in success_logits:
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
return {"progress_pred": progress_pred, "success_probs": success_probs}
class RobometerRewardModel(PreTrainedRewardModel):
"""Robometer (RBM) reward model — inference-only LeRobot port.
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
prediction heads from the paper (progress, success, preference). At
inference time only the progress and success heads are queried; the
preference head is kept on the module so the published ``Robometer-4B``
safetensors load unchanged.
"""
name = "robometer"
config_class = RobometerConfig
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
require_package("transformers", extra="robometer")
super().__init__(config)
self.config = config
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
#
# - Fresh training (``pretrained_path is None``): download the base
# Qwen weights and resize the embed table to match
# ``vlm_config.text_config.vocab_size`` — populated deterministically
# in ``RobometerConfig.__post_init__`` as
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
#
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
# the empty architecture from ``vlm_config`` via
# ``AutoModelForImageTextToText.from_config`` so the subsequent
# ``model.safetensors`` load is a direct fill of the right shape —
# no redundant Qwen weight download.
torch_dtype = _torch_dtype(config.torch_dtype)
if config.pretrained_path is None:
self.model = AutoModelForImageTextToText.from_pretrained(
config.base_model_id,
dtype=torch_dtype,
trust_remote_code=True,
)
target_vocab = config.vlm_config["text_config"]["vocab_size"]
self.model.resize_token_embeddings(target_vocab)
else:
self.model = AutoModelForImageTextToText.from_config(
config.vlm_backbone_config,
dtype=torch_dtype,
trust_remote_code=True,
)
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
# Falls back to the top-level `hidden_size` so future non-multimodal
# variants would still resolve.
backbone_config = self.model.config
text_config = getattr(backbone_config, "text_config", None)
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
if hidden_size is None:
hidden_size = getattr(backbone_config, "hidden_size", None)
if hidden_size is None:
raise AttributeError(
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
)
hidden_dim = int(hidden_size)
# Robometer's three prediction heads + frame-pool attention.
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
self.progress_head = RobometerPredictionHead(
hidden_dim,
progress_output,
dropout=dropout,
with_sigmoid=not config.use_discrete_progress,
)
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
# Match the dtype of the loaded base model so weight loading is a no-op cast.
model_dtype = next(self.model.parameters()).dtype
self.progress_head.to(dtype=model_dtype)
self.preference_head.to(dtype=model_dtype)
self.success_head.to(dtype=model_dtype)
self.frame_pool_attn.to(dtype=model_dtype)
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
inputs = {
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
for key in ROBOMETER_INPUT_KEYS
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
}
if "input_ids" not in inputs:
raise KeyError(
f"Robometer batch missing pre-encoded inputs (expected "
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
"RobometerEncoderProcessorStep ran before `compute_reward`."
)
device = next(self.model.parameters()).device
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
self.eval()
with torch.no_grad():
progress_logits, success_logits = self._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=self.config.use_discrete_progress,
)
values = (
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
)
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
if self.config.reward_output == "success":
rewards = (rewards > self.config.success_threshold).float()
else:
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
# progress predictions are clamped to ``[0, 1]`` before being returned.
rewards = rewards.clamp(0.0, 1.0)
return rewards.to(self.config.device or "cpu")
def _compute_rbm_logits(
self,
inputs: dict[str, Any],
) -> tuple[Tensor, Tensor]:
"""Run the Qwen3-VL backbone and apply Robometer's heads.
``inputs`` is the encoded batch produced by
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
as Robometer-specific metadata (``prog_token_id``,
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
the metadata is popped here so the rest can be forwarded straight to
the Qwen model.
Returns ``(progress_logits, success_logits)``. Shapes:
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
"""
prog_token_id = inputs.pop("prog_token_id", None)
vision_start_token_id = inputs.pop("vision_start_token_id", None)
vision_end_token_id = inputs.pop("vision_end_token_id", None)
video_merge_size = inputs.pop("video_merge_size", 14)
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
# full hidden-state tuple and take the last layer. This matches the
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
hidden_state = (
outputs.hidden_states[-1]
if getattr(outputs, "hidden_states", None)
else outputs.last_hidden_state
)
input_ids = inputs["input_ids"]
if self.config.use_per_frame_progress_token:
if prog_token_id is None:
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
if self.config.use_multi_image:
if vision_start_token_id is None or vision_end_token_id is None:
raise KeyError(
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
"(run RobometerEncoderProcessorStep first)"
)
return self._process_multi_image_frames(
hidden_state,
input_ids,
start_id=vision_start_token_id,
end_id=vision_end_token_id,
)
video_grid_thw = inputs.get("video_grid_thw")
if video_grid_thw is None:
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
if vision_start_token_id is None:
raise KeyError("`vision_start_token_id` missing in batch")
return self._process_video_frames(
hidden_state,
input_ids,
video_grid_thw,
start_id=vision_start_token_id,
merge_size=video_merge_size,
)
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
"""Apply progress + success heads to a tensor of frame embeddings."""
progress_out = self.progress_head(frame_embeddings)
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
success = _squeeze_last_safe(self.success_head(frame_embeddings))
return progress, success
def _process_token_extraction(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
prog_token_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
token_mask = input_ids == prog_token_id
batch_indices, positions = token_mask.nonzero(as_tuple=True)
if positions.numel() == 0:
raise ValueError("`<|prog_token|>` not found in any sequence")
per_sample_hidden = [
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
]
progress_list, success_list = [], []
for embeddings in per_sample_hidden:
if embeddings.shape[0] == 0:
raise ValueError("`<|prog_token|>` missing in a sequence")
progress, success = self._apply_heads_to_hidden_states(embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _process_multi_image_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
start_id: int,
end_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
frame_embeddings = self._extract_hidden_states_from_token_pairs(
seq_hidden, seq_ids, start_id, end_id
)
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _extract_hidden_states_from_token_pairs(
self,
hidden_state: Tensor,
input_ids: Tensor,
start_id: int,
end_id: int,
) -> Tensor:
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
if start_positions.numel() != end_positions.numel():
raise ValueError(
f"Mismatched vision token counts: {start_positions.numel()} start vs "
f"{end_positions.numel()} end"
)
frames: list[Tensor] = []
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
if start >= end:
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
patch_tokens = hidden_state[start + 1 : end]
if patch_tokens.shape[0] == 0:
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
continue
pooling = self.config.frame_pooling
if pooling == "mean":
frames.append(patch_tokens.mean(dim=0))
elif pooling == "boundary":
frames.append(patch_tokens[-1])
else: # attention
scores = (
self.frame_pool_attn(patch_tokens).squeeze(-1)
/ self.config.frame_pooling_attn_temperature
)
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
frames.append((weights * patch_tokens).sum(dim=0))
return torch.stack(frames)
def _process_video_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
video_grid_thw: Tensor,
*,
start_id: int,
merge_size: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in video mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
cursor = start_positions[0].item()
frame_embeddings: list[Tensor] = []
for _ in range(t_dim):
if self.config.average_temporal_patches:
patch = seq_hidden[cursor : cursor + tokens_per_frame]
frame_embeddings.append(patch.mean(dim=0))
else:
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
cursor += tokens_per_frame
stacked = torch.stack(frame_embeddings)
progress, success = self._apply_heads_to_hidden_states(stacked)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
@@ -0,0 +1,338 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Robometer pre/post processing pipelines."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
policy_action_to_transition,
)
from lerobot.rewards.robometer.configuration_robometer import (
ROBOMETER_SPECIAL_TOKENS,
RobometerConfig,
)
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor
else:
AutoProcessor = None
PROGRESS_PROMPT = (
"The task for the robot is '{task}'. Given the trajectory video, predict "
"the task progress at each frame, how far along the robot is towards "
"completing the task, a float between 0 and 1, where 0 is the starting "
"state and 1 is when the task is completed. If the robot is not "
"performing the same task, predict 0 progress."
)
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
if frames.ndim != 4:
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
if frames.dtype != np.uint8:
frames = np.clip(frames, 0, 255).astype(np.uint8)
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
if max_frames is not None:
video = video[-max_frames:]
if video.shape[1] in (1, 3):
video = video.permute(0, 2, 3, 1)
elif video.shape[-1] not in (1, 3):
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
array = video.detach().cpu().numpy()
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
array = array * 255.0
return np.clip(array, 0, 255).astype(np.uint8)
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
if task is None:
task = default
if task is None:
raise KeyError("Robometer expected a task description in complementary data")
if isinstance(task, str):
return [task] * batch_size
if isinstance(task, tuple):
task = list(task)
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
if len(task) == 1 and batch_size > 1:
return task * batch_size
if len(task) != batch_size:
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
return task
@dataclass
@ProcessorStepRegistry.register(name="robometer_encoder")
class RobometerEncoderProcessorStep(ProcessorStep):
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
registers Robometer's special tokens on the tokenizer. The matching
embedding resize happens model-side in
:meth:`RobometerRewardModel.__init__`.
At call time the step reads:
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
- ``complementary_data[task_key]``: a string or list of strings.
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
- Robometer-specific token ids consumed by the model heads:
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
``video_merge_size``.
"""
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
max_length: int = 1024
_processor: Any = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
require_package("transformers", extra="robometer")
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
self._processor = AutoProcessor.from_pretrained(
self.base_model_id,
trust_remote_code=True,
do_sample_frames=False,
padding_side="right",
)
# Register Robometer's special tokens on the tokenizer. The matching
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
tokenizer = self._processor.tokenizer
# Qwen tokenizers may not define a pad token, but batched prompts/videos
# require padding, so reuse EOS as the padding token.
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
for token in ROBOMETER_SPECIAL_TOKENS:
if token not in tokenizer.get_vocab():
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if not isinstance(observation, dict):
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
if self.image_key not in observation:
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
frames = observation[self.image_key]
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
if tensor.ndim == 4:
tensor = tensor.unsqueeze(1)
elif tensor.ndim != 5:
raise ValueError(
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
)
batch_size = tensor.shape[0]
tasks = _expand_tasks(
complementary.get(self.task_key, self.default_task),
batch_size=batch_size,
default=self.default_task,
)
samples = [
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
]
encoded = self.encode_samples(samples)
new_observation = dict(observation)
for key, value in encoded.items():
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
from qwen_vl_utils import process_vision_info
conversations = [self._build_conversation(frames, task) for frames, task in samples]
texts = [
self._processor.apply_chat_template(
msg,
tokenize=False,
add_generation_prompt=False,
add_vision_id=True,
enable_thinking=False,
fps=1,
)
for msg in conversations
]
process_kwargs: dict[str, Any] = {
"return_video_kwargs": True,
"return_video_metadata": True,
}
image_processor = getattr(self._processor, "image_processor", None)
if image_processor is not None and hasattr(image_processor, "patch_size"):
process_kwargs["image_patch_size"] = image_processor.patch_size
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
videos: list[Any] | None = None
video_metadatas: list[Any] | None = None
if video_inputs:
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
videos = list(videos_seq)
video_metadatas = list(metadatas_seq)
else:
videos = list(video_inputs)
processor_kwargs: dict[str, Any] = {
"text": texts,
"images": image_inputs,
"padding": True,
"truncation": False,
"max_length": self.max_length,
"return_tensors": "pt",
"do_resize": False,
}
if videos is not None:
processor_kwargs["videos"] = videos
if video_metadatas is not None:
processor_kwargs["video_metadata"] = video_metadatas
if video_kwargs:
processor_kwargs.update(video_kwargs)
encoded = self._processor(**processor_kwargs)
# Write Robometer-specific token ids and the video patch merge size into
# the encoded batch so `RobometerRewardModel` doesn't need its own
# tokenizer at inference (EO1-style separation: the processor owns the
# tokenizer, the model owns the backbone and heads).
tokenizer = self._processor.tokenizer
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
video_processor = getattr(self._processor, "video_processor", None)
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
return encoded
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
pil_frames = _frames_to_pil(frames)
prompt = PROGRESS_PROMPT.format(task=task)
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
if self.use_multi_image:
for image in pil_frames:
content.append({"type": "image", "image": image})
if self.use_per_frame_progress_token:
content.append({"type": "text", "text": "<|prog_token|>"})
else:
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
return [{"role": "user", "content": content}]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"base_model_id": self.base_model_id,
"image_key": self.image_key,
"task_key": self.task_key,
"default_task": self.default_task,
"max_frames": self.max_frames,
"use_multi_image": self.use_multi_image,
"use_per_frame_progress_token": self.use_per_frame_progress_token,
"max_length": self.max_length,
}
def make_robometer_pre_post_processors(
config: RobometerConfig,
dataset_stats: dict[str, dict[str, Any]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
The preprocessor adds a batch dimension if needed, runs Robometer's
encoder, and moves everything to the configured device. The
postprocessor is the identity since Robometer outputs a single reward
tensor.
"""
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=config.max_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
+4
View File
@@ -23,6 +23,7 @@ from .configs import (
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
@@ -49,6 +50,7 @@ from .inference import (
from .strategies import (
BaseStrategy,
DAggerStrategy,
EpisodicStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
@@ -66,6 +68,8 @@ __all__ = [
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"EpisodicStrategy",
"EpisodicStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
+36 -1
View File
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("episodic")
@dataclass
class EpisodicStrategyConfig(RolloutStrategyConfig):
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
Keyboard controls:
Right arrow end current episode or reset phase early
Left arrow discard current episode and re-record
Escape stop recording session
In between episodes:
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
- else, the robot is moved smoothly to the position of the teleop leader.
"""
# This only applies if there are no teleop leaders specified.
# When True (default), moves the robot back to the joint positions captured at startup.
# Otherwise, leave the robot in its current position.
reset_to_initial_position: bool = True
# Whether to turn on or off the leader -> follower smooth handover behavior.
# When False, fallback to follower -> leader handover.
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
smooth_leader_to_follower_handover: bool = True
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
@@ -229,7 +258,13 @@ class RolloutConfig:
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
self.strategy,
(
SentryStrategyConfig,
HighlightStrategyConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
),
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
@@ -17,6 +17,7 @@
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .episodic import EpisodicStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -27,6 +28,7 @@ __all__ = [
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"EpisodicStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
+14 -69
View File
@@ -56,10 +56,14 @@ from typing import Any
import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.common.control_utils import (
follower_smooth_move_to,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
@@ -171,64 +174,6 @@ class DAggerEvents:
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def _follower_smooth_move_to(
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm
to the follower, we bring the follower to the teleop's current pose.
Both ``current`` and ``target`` must be in robot-action key space.
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Pausing engine - robot holds position")
engine.pause()
if _teleop_supports_feedback(teleop) and prev_action is not None:
if teleop_supports_feedback(teleop) and prev_action is not None:
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
logger.info("Smooth handover: moving leader arm to follower position")
_teleop_smooth_move_to(teleop, prev_action)
teleop_smooth_move_to(teleop, prev_action)
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode - human teleop control")
if not _teleop_supports_feedback(teleop) and prev_action is not None:
if not teleop_supports_feedback(teleop) and prev_action is not None:
logger.info("Smooth handover: sliding follower to teleop position")
obs = robot.get_observation()
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
_follower_smooth_move_to(robot, prev_action, target)
follower_smooth_move_to(robot, prev_action, target)
# unlock the teleop for human control
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.enable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
engine.resume()
# release teleop before resuming the policy
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
# ------------------------------------------------------------------
+335
View File
@@ -0,0 +1,335 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
- Policy drives the robot during each recording episode.
- An optional teleoperator can drive the robot during reset phases so the
operator can bring the environment back to its starting configuration.
If no teleop is connected the robot stays in its current position.
- Keyboard controls:
Right arrow end the current episode or reset phase early
Left arrow discard the current episode and re-record it
Escape stop the recording session
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
"""
from __future__ import annotations
import contextlib
import logging
import time
from lerobot.common.control_utils import (
follower_smooth_move_to,
init_keyboard_listener,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_rerun_data
from ..configs import EpisodicStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
class EpisodicStrategy(RolloutStrategy):
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
follows every episode (except the last) so the operator can manually
reset the environment. During the reset phase, an optional teleoperator
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
The policy state (hidden state, RTC queue, interpolator) is reset at
the start of each recording episode.
Keyboard events:
right arrow end current episode or reset phase early
left arrow discard & re-record current episode
ESC stop the session
"""
config: EpisodicStrategyConfig
def __init__(self, config: EpisodicStrategyConfig) -> None:
super().__init__(config)
self._listener = None
self._events: dict | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Start the inference engine and attach the keyboard listener."""
self._init_engine(ctx)
self._listener, self._events = init_keyboard_listener()
logger.info("Episodic strategy ready")
def run(self, ctx: RolloutContext) -> None:
"""Main multi-episode recording loop."""
cfg = ctx.runtime.cfg
dataset_cfg = cfg.dataset
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
features = ctx.data.dataset_features
fps = cfg.fps
episode_time_s = dataset_cfg.episode_time_s
reset_time_s = dataset_cfg.reset_time_s
num_episodes = dataset_cfg.num_episodes
single_task = dataset_cfg.single_task or cfg.task
play_sounds = cfg.play_sounds
display_compressed = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
else cfg.display_compressed_images
)
with VideoEncodingManager(dataset):
try:
recorded_episodes = 0
while recorded_episodes < num_episodes and not events["stop_recording"]:
if ctx.runtime.shutdown_event.is_set():
break
# Reset policy state at episode start (discard leftover hidden state / queue)
self._engine.reset()
self._interpolator.reset()
self._engine.resume()
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
self._policy_loop(
ctx=ctx,
robot=robot,
events=events,
features=features,
fps=fps,
control_time_s=episode_time_s,
dataset=dataset,
single_task=single_task,
)
# Reset phase, skip after the last episode (but run when re-recording)
if not events["stop_recording"] and (
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
if teleop:
# Smooth handover so the transition to teleop control is jerk-free.
# For actuated teleops: drive the leader arm to the follower's current
# position so the operator takes over without fighting the arm.
# For non-actuated teleops: slide the follower to the teleop's current
# pose instead, since the leader cannot be driven.
obs = robot.get_observation()
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
if (
teleop_supports_feedback(teleop)
and self.config.smooth_leader_to_follower_handover
):
logger.info("Smooth handover: moving leader arm to follower position")
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
teleop.disable_torque()
else:
logger.info("Smooth handover: sliding follower to teleop position")
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
elif self.config.reset_to_initial_position:
# No teleop: return the robot to its startup position.
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
self._reset_loop(
ctx=ctx,
robot=robot,
teleop=teleop,
events=events,
fps=fps,
control_time_s=reset_time_s,
display_data=cfg.display_data,
display_compressed=display_compressed,
)
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
# returns to its initial joint positions captured at startup
if not teleop and self.config.reset_to_initial_position:
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
continue
dataset.save_episode()
recorded_episodes += 1
finally:
# Save any frames buffered in the current episode so an unexpected
# exception or KeyboardInterrupt does not silently drop recorded data.
# suppress: save_episode raises if the buffer is empty (nothing to lose).
logger.info("Episodic control loop ended — saving any in-progress episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def _policy_loop(
self,
ctx: RolloutContext,
robot,
events: dict,
features: dict,
fps: float,
control_time_s: float,
dataset,
single_task: str,
) -> None:
"""Policy-driven recording loop for a single episode."""
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(fps)
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
if sleep_t < 0:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
"Dataset frames might be dropped and robot control might be unstable. "
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
"3) CPU starvation"
)
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def _reset_loop(
self,
ctx: RolloutContext,
robot,
teleop,
events: dict,
fps: float,
control_time_s: float,
display_data: bool,
display_compressed: bool,
) -> None:
"""Reset-phase loop: teleop drives the robot if available, no recording."""
processors = ctx.processors
control_interval = 1.0 / fps
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
if teleop is not None:
act = teleop.get_action()
act_teleop = processors.teleop_action_processor((act, obs))
robot_action = processors.robot_action_processor((act_teleop, obs))
robot.send_action(robot_action)
if display_data:
obs_processed = processors.robot_observation_processor(obs)
log_rerun_data(
observation=obs_processed,
action=act_teleop,
compress_images=display_compressed,
)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def teardown(self, ctx: RolloutContext) -> None:
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
cfg = ctx.runtime.cfg
play_sounds = cfg.play_sounds
log_say("Stop recording", play_sounds, blocking=True)
if not is_headless() and self._listener is not None:
self._listener.stop()
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if (
cfg.dataset is not None
and cfg.dataset.push_to_hub
and ctx.data.dataset is not None
and safe_push_to_hub(
ctx.data.dataset,
tags=cfg.dataset.tags,
private=cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(
ctx.hardware,
return_to_initial_position=cfg.return_to_initial_position,
)
log_say("Exiting", play_sounds)
logger.info("Episodic strategy teardown complete")
+6 -1
View File
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .episodic import EpisodicStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
return HighlightStrategy(config)
if config.type == "dagger":
return DAggerStrategy(config)
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
if config.type == "episodic":
return EpisodicStrategy(config)
raise ValueError(
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
)
+206
View File
@@ -0,0 +1,206 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-annotate`` — populate ``language_persistent`` and
``language_events`` columns on a LeRobot dataset.
Annotations live directly in ``data/chunk-*/file-*.parquet``.
Example:
uv run lerobot-annotate \\
--root=/path/to/dataset \\
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
For distributed runs, see ``examples/annotations/run_hf_job.py``.
"""
import logging
from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
PlanSubtasksMemoryModule,
)
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
from lerobot.configs import parser
logger = logging.getLogger(__name__)
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
if cfg.root is not None:
return Path(cfg.root)
if cfg.repo_id is not None:
from huggingface_hub import snapshot_download
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
raise ValueError("Either --root or --repo_id must be provided.")
@parser.wrap()
def annotate(cfg: AnnotationPipelineConfig) -> None:
"""Run the steerable annotation pipeline against a dataset."""
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
root = _resolve_root(cfg)
logger.info("annotate: root=%s", root)
vlm = make_vlm_client(cfg.vlm)
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend)
# Surface the resolved cameras up front so a silent vqa-module no-op
# is obvious in job output rather than discovered post-hoc by counting
# parquet rows.
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
logger.info(
"annotate: frame_provider default camera=%r, all cameras=%s",
getattr(frame_provider, "camera_key", None),
cam_keys,
)
if cfg.vqa.enabled and not cam_keys:
logger.warning(
"annotate: the vqa module is enabled but no cameras were "
"resolved — it will produce zero VQA rows. Check "
"meta/info.json for observation.images.* features, or pass "
"--vlm.camera_key=<key> to seed the cameras list."
)
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
interjections = InterjectionsAndSpeechModule(
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
)
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
writer = LanguageColumnsWriter()
validator = StagingValidator(
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
)
executor = Executor(
config=cfg,
plan=plan,
interjections=interjections,
vqa=vqa,
writer=writer,
validator=validator,
)
summary = executor.run(root)
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
for phase in summary.phases:
logger.info(
"annotate: phase=%s processed=%d skipped=%d",
phase.name,
phase.episodes_processed,
phase.episodes_skipped,
)
if summary.validation_report.warnings:
for w in summary.validation_report.warnings:
logger.warning(w)
if cfg.push_to_hub:
if cfg.repo_id is None and cfg.new_repo_id is None:
raise ValueError(
"--push_to_hub requires --repo_id or --new_repo_id (the dataset repo to push to)."
)
_push_to_hub(root, cfg)
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
"""Upload the annotated dataset directory to the Hub.
Pushes to ``cfg.new_repo_id`` when set, otherwise back to ``cfg.repo_id``.
"""
from huggingface_hub import HfApi # noqa: PLC0415
repo_id = cfg.new_repo_id or cfg.repo_id
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
api = HfApi()
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
api.create_repo(
repo_id=repo_id,
repo_type="dataset",
private=cfg.push_private,
exist_ok=True,
)
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
commit_info = api.upload_folder(
folder_path=str(root),
repo_id=repo_id,
repo_type="dataset",
commit_message=commit_message,
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
)
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
# resolves the dataset revision via ``get_safe_version`` which scans
# for tags like ``v3.0``; without a tag it raises
# ``RevisionNotFoundError``. Read the version straight from the
# dataset's own ``meta/info.json`` so we tag whatever the writer
# actually wrote (no accidental drift if the codebase floor moves).
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
info_path = root / "meta" / "info.json"
version_tag = CODEBASE_VERSION
if info_path.exists():
try:
from lerobot.utils.io_utils import load_json # noqa: PLC0415
info = load_json(info_path)
ds_version = info.get("codebase_version")
if isinstance(ds_version, str) and ds_version.startswith("v"):
version_tag = ds_version
except Exception as exc: # noqa: BLE001
print(
f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}",
flush=True,
)
revision = getattr(commit_info, "oid", None)
tag_kwargs = {
"repo_id": repo_id,
"tag": version_tag,
"repo_type": "dataset",
}
if revision is not None:
tag_kwargs["revision"] = revision
try:
from contextlib import suppress # noqa: PLC0415
from huggingface_hub.errors import RevisionNotFoundError # noqa: PLC0415
with suppress(RevisionNotFoundError):
api.delete_tag(repo_id, tag=version_tag, repo_type="dataset")
api.create_tag(**tag_kwargs)
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
except Exception as exc: # noqa: BLE001
print(
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
"Run: from huggingface_hub import HfApi; "
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
flush=True,
)
def main() -> None:
annotate()
if __name__ == "__main__":
main()
@@ -94,6 +94,14 @@ Merge multiple datasets from a list of local dataset paths:
--operation.repo_ids "['pusht_train', 'pusht_val']" \
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
Merge multiple datasets while keeping one file per source file (no video/data stitching):
lerobot-edit-dataset \
--new_repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \
--operation.concatenate_videos false \
--operation.concatenate_data false
Remove camera feature:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
@@ -257,6 +265,9 @@ class SplitConfig(OperationConfig):
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
roots: list[str] | None = None
# When False, keep one file per source file instead of packing into shards.
concatenate_videos: bool = True
concatenate_data: bool = True
@OperationConfig.register_subclass("remove_feature")
@@ -461,6 +472,8 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
datasets,
output_repo_id=cfg.new_repo_id,
output_dir=output_dir,
concatenate_videos=cfg.operation.concatenate_videos,
concatenate_data=cfg.operation.concatenate_data,
)
logging.info(f"Merged dataset saved to {output_dir}")
+159 -19
View File
@@ -72,8 +72,9 @@ from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.configs import parser
from lerobot.configs import FeatureType, parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs import (
check_env_attributes_and_types,
close_envs,
@@ -84,7 +85,7 @@ from lerobot.envs import (
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
@@ -95,6 +96,81 @@ from lerobot.utils.utils import (
)
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
If raw_obs is provided, visual feature shapes are inferred from the actual observation
to avoid mismatches between the env config and the real observation resolution.
"""
features = {}
for key, ft in env_features.items():
if ft.type is FeatureType.VISUAL:
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
elif raw_obs is not None and "pixels" in raw_obs:
pixels = raw_obs["pixels"]
if isinstance(pixels, dict):
for cam_name, img in pixels.items():
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
shape = img.shape[1:] # strip batch dim
elif key in ("pixels", OBS_IMAGE):
shape = pixels.shape[1:] # strip batch dim
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
features[key] = {"dtype": "float32", "shape": shape, "names": None}
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
return features
def _build_raw_frame(
raw_obs: dict,
env_idx: int,
action: np.ndarray,
reward: float,
success: bool,
done: bool,
task: str,
env_features: dict,
) -> dict:
"""Build a dataset frame from raw env observations for one env index.
Keys in the frame match the keys in env_features so they align with the
dataset schema created by _env_features_to_dataset_features().
"""
frame: dict[str, Any] = {}
for key in env_features:
if key == ACTION:
continue
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
for cam_name, img in raw_obs["pixels"].items():
candidate = f"{OBS_IMAGES}.{cam_name}"
if candidate == key:
frame[key] = img[env_idx]
if key in frame:
continue
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
frame[key] = raw_obs["pixels"][env_idx]
continue
raw_key = key
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
val = raw_obs[raw_key][env_idx]
if val.dtype == np.float64:
val = val.astype(np.float32)
frame[key] = val
frame[ACTION] = action
frame["next.reward"] = np.atleast_1d(np.float32(reward))
frame["next.success"] = np.atleast_1d(np.bool_(success))
frame["next.done"] = np.atleast_1d(np.bool_(done))
frame["task"] = task
return frame
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
@@ -105,6 +181,7 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
recording_dataset: Any | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -145,6 +222,14 @@ def rollout(
if render_callback is not None:
render_callback(env)
raw_observation = deepcopy(observation) if recording_dataset is not None else None
task_desc = ""
if recording_dataset is not None:
try:
task_desc = list(env.call("task_description"))[0]
except (AttributeError, NotImplementedError):
task_desc = ""
all_observations = []
all_actions = []
all_rewards = []
@@ -217,6 +302,26 @@ def rollout(
else:
successes = [False] * env.num_envs
if recording_dataset is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_dataset.features,
)
recording_dataset.add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_dataset.save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`,
@@ -273,6 +378,7 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
recording_dataset: Any | None = None,
) -> dict:
"""
Args:
@@ -361,6 +467,7 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
recording_dataset=recording_dataset,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -563,6 +670,10 @@ def eval_main(cfg: EvalPipelineConfig):
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
@@ -572,10 +683,13 @@ def eval_main(cfg: EvalPipelineConfig):
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=False,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
recording_dir=recording_dir,
env_features=cfg.env.features if cfg.eval.recording else None,
)
print("Overall Aggregated Metrics:")
print(info["overall"])
@@ -618,6 +732,7 @@ def eval_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dataset: Any | None = None,
) -> TaskMetrics:
"""Evaluates one task_id of one suite using the provided vec env."""
@@ -635,6 +750,7 @@ def eval_one(
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dataset=recording_dataset,
)
per_episode = task_result["per_episode"]
@@ -661,6 +777,8 @@ def run_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
):
"""
Run eval_one for a single (task_group, task_id, env).
@@ -672,21 +790,39 @@ def run_one(
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
task_videos_dir.mkdir(parents=True, exist_ok=True)
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
)
# ensure we always provide video_paths key to simplify accumulation
recording_dataset = None
if recording_dir is not None and env_features is not None:
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
fps = env.unwrapped.metadata.get("render_fps", 30)
sample_obs, _ = env.reset()
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
recording_dataset = LeRobotDataset.create(
repo_id=f"eval_{task_group}_{task_id}",
fps=fps,
features=features,
root=str(task_recording_dir),
use_videos=True,
)
try:
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dataset=recording_dataset,
)
finally:
if recording_dataset is not None:
recording_dataset.finalize()
if max_episodes_rendered > 0:
metrics.setdefault("video_paths", [])
return task_group, task_id, metrics
@@ -702,6 +838,8 @@ def eval_policy_all(
n_episodes: int,
*,
max_episodes_rendered: int = 0,
recording_dir: Path | None = None,
env_features: dict | None = None,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
@@ -761,6 +899,8 @@ def eval_policy_all(
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
)
if max_parallel_tasks <= 1:
+13
View File
@@ -25,6 +25,7 @@ Strategies
--strategy.type=sentry Continuous recording with auto-upload
--strategy.type=highlight Ring buffer + keystroke save
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
--strategy.type=episodic Episode-oriented recording with reset phases
Inference backends
------------------
@@ -111,6 +112,18 @@ Usage examples
--display_data=true \\
--use_torch_compile=true
# Episodic mode — episode-oriented recording with reset phases
lerobot-rollout \\
--strategy.type=episodic \\
--policy.path=user/my_policy \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--teleop.type=so100_leader \\
--teleop.port=/dev/ttyACM1 \\
--dataset.repo_id=user/rollout_episodic_data \\
--dataset.num_episodes=20 \\
--dataset.single_task="Grab the cube"
# Resume a previous sentry recording session
lerobot-rollout \\
--strategy.type=sentry \\
+92 -36
View File
@@ -36,6 +36,8 @@ from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
save_checkpoint,
update_last_checkpoint,
@@ -43,7 +45,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -99,6 +101,9 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
@@ -158,6 +163,8 @@ def update_policy(
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
if torch.cuda.is_available():
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
return train_metrics, output_dict
@@ -232,14 +239,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions
# Dataset loading synchronization: the global main process downloads once to the shared
# dataset root, then a barrier lets every other rank read the already-populated copy.
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
# Other ranks read from the shared copy populated by the main process.
if not is_main_process:
dataset = make_dataset(cfg)
@@ -292,19 +301,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
if (
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_relative_actions=true with pretrained processors can skip relative transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
processor_kwargs = {}
postprocessor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -312,24 +310,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
processor_kwargs["dataset_meta"] = dataset.meta
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
preprocessor_overrides = {
"device_processor": {"device": device.type},
"normalizer_processor": {
"stats": dataset.meta.stats,
"features": {**policy.config.input_features, **policy.config.output_features},
"norm_map": policy.config.normalization_mapping,
},
"rename_observations_processor": {"rename_map": cfg.rename_map},
}
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
"rename_map": cfg.rename_map
}
postprocessor_kwargs["postprocessor_overrides"] = {
postprocessor_overrides = {
"unnormalizer_processor": {
"stats": dataset.meta.stats,
"features": policy.config.output_features,
"norm_map": policy.config.normalization_mapping,
},
}
if getattr(active_cfg, "use_relative_actions", False):
preprocessor_overrides["relative_actions_processor"] = {
"enabled": True,
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
"action_names": getattr(active_cfg, "action_feature_names", None),
}
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
if cfg.is_reward_model_training:
preprocessor, postprocessor = make_reward_pre_post_processors(
@@ -341,7 +346,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
@@ -389,15 +393,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
if not cfg.dataset.streaming:
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
# The order is a pure function of (seed, epoch), so every rank independently produces the
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
# use the values recorded in the checkpoint (falling back to the current ones for older
# ckpts that did not store them).
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
ckpt_num_processes = saved_num_processes or accelerator.num_processes
ckpt_batch_size = saved_batch_size or cfg.batch_size
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
logging.warning(
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
f"written with num_processes={saved_num_processes}. The data order resumes at the "
"right epoch/offset, but per-rank sample-exactness requires the same world size."
)
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
logging.warning(
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
"but per-rank sample-exactness requires the same batch size."
)
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
else:
shuffle = True
sampler = None
@@ -429,12 +465,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
# is actually optimizing. grad_norm and lr are already identical on every rank (post
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
# true straggler instead of rank 0's view.
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
# Derived from the post-reduce max step time; set once per log window on the main rank.
"samples_per_s": AverageMeter("smp/s", ":.0f"),
}
if torch.cuda.is_available():
# max() because headroom is gated by the worst-case rank.
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
@@ -486,21 +532,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
# Collective reduce must run on every rank, before the main-process gate below.
train_tracker.reduce_across_ranks()
if is_main_process:
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
# reflects the slowest rank — which is what actually gates the next iteration.
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
if step_time > 0:
train_tracker.samples_per_s = effective_batch_size / step_time
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
@@ -516,6 +570,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
@@ -13,77 +13,213 @@
[SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware.
{% elif model_name == "act" %}
[Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
{% elif model_name == "tdmpc" %}
[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function.
{% elif model_name == "diffusion" %}
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
{% elif model_name == "pi0" %}
**π₀ (Pi0)**
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
[π₀ (Pi0)](https://www.physicalintelligence.company/blog/pi0) is a general-purpose robot foundation model from Physical Intelligence: a generalist Vision-Language-Action policy that understands visual inputs, interprets natural language instructions, and controls a variety of different robots across diverse tasks. The LeRobot implementation is adapted from their open-source OpenPI repository.
{% elif model_name == "pi05" %}
**π₀.₅ (Pi05) Policy**
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
[π₀.₅ (Pi05)](https://www.physicalintelligence.company/blog/pi05) is a Vision-Language-Action model from Physical Intelligence designed for open-world generalization: it evolves π₀ to generalize to entirely new environments and situations that were never seen during training. The LeRobot implementation is adapted from their open-source OpenPI repository.
{% elif model_name == "molmoact2" %}
[MolmoAct2](https://allenai.org/blog/molmoact2) is an open robotics foundation model from the Allen Institute for AI (Ai2) that maps camera images and language instructions to robot action chunks. The LeRobot implementation supports training and evaluation of the regular MolmoAct2 model.
{% elif model_name == "vla_jepa" %}
[VLA-JEPA](https://arxiv.org/abs/2602.10098) is a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
{% elif model_name == "gaussian_actor" %}
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
{% elif model_name == "pi0_fast" %}
[π₀-FAST (Pi0-FAST)](https://www.physicalintelligence.company/research/fast) is a Vision-Language-Action model for general robot control, from Physical Intelligence. It models continuous robot actions with autoregressive next-token prediction using FAST (Frequency-space Action Sequence Tokenization), training up to 5x faster than diffusion-based π₀.
{% elif model_name == "eo1" %}
[EO-1](https://huggingface.co/papers/2508.21112) is a Vision-Language-Action model for general robot control. It pairs a Qwen2.5-VL backbone for vision-language understanding with a continuous flow-matching action head that denoises action chunks.
{% elif model_name == "groot" %}
[GR00T N1.5](https://github.com/NVIDIA/Isaac-GR00T) is an open, cross-embodiment foundation model from NVIDIA for generalized humanoid robot reasoning and skills. It takes language and images as input and uses a flow-matching action transformer to predict actions conditioned on vision, language, and proprioception.
{% elif model_name == "multi_task_dit" %}
[Multi-Task Diffusion Transformer (DiT)](https://huggingface.co/papers/2507.05331) extends Diffusion Policy with a large Diffusion Transformer and text + vision conditioning for multi-task robot learning. It supports both diffusion and flow-matching objectives and reaches high dexterity with only ~450M parameters.
{% elif model_name == "wall_x" %}
[WALL-OSS](https://huggingface.co/papers/2509.11766) is an open-source foundation model for embodied intelligence from XSquare Robot. Built on Qwen2.5-VL, it uses a tightly-coupled multimodal architecture with flow matching to unify semantic reasoning and high-frequency action generation for cross-embodiment control.
{% elif model_name == "xvla" %}
[X-VLA](https://huggingface.co/papers/2510.10274) is a soft-prompted, flow-matching Vision-Language-Action framework that treats each robot or hardware setup as a "task" encoded with a small set of learnable Soft Prompt embeddings, letting a single model reconcile diverse robot morphologies, sensors, and action spaces.
{% else %}
_Model type not recognized — please update this template._
This is a **{{ model_name }}** policy trained with [LeRobot](https://github.com/huggingface/lerobot).
{% endif %}
{% set diagrams = {
"smolvla": "https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png",
"pi0": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png",
"pi0_fast": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png",
"eo1": "https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png",
"groot": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png",
"wall_x": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png",
"xvla": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
} %}
{% if diagrams.get(model_name) %}
<p align="center">
<img src="{{ diagrams[model_name] }}" alt="{{ model_name }} architecture" width="85%"/>
</p>
{% endif %}
<!-- A short demo is worth more than any description! Record a GIF/video of the policy
running on your robot, upload it to this repo, and embed it here:
<p align="center">
<img src="https://huggingface.co/<hf_user>/<policy_repo_id>/resolve/main/demo.gif" width="60%"/>
</p>
-->
This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
---
## How to Get Started with the Model
For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
Below is the short version on how to train and run inference/eval:
### Train from scratch
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.type=act \
--output_dir=outputs/train/<desired_policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<desired_policy_repo_id>
--wandb.enable=true
```
_Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
### Evaluate the policy/run inference
```bash
lerobot-record \
--robot.type=so100_follower \
--dataset.repo_id=<hf_user>/eval_<dataset> \
--policy.path=<hf_user>/<desired_policy_repo_id> \
--episodes=10
```
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
{% set policy_docs = {
"act": "act",
"smolvla": "smolvla",
"pi0": "pi0",
"pi0_fast": "pi0fast",
"pi05": "pi05",
"molmoact2": "molmoact2",
"vla_jepa": "vla_jepa",
"eo1": "eo1",
"groot": "groot",
"xvla": "xvla",
"multi_task_dit": "multi_task_dit",
"wall_x": "walloss"
} %}
{% if policy_docs.get(model_name) %}Learn how to train and run it in the [LeRobot {{ model_name }} guide](https://huggingface.co/docs/lerobot/main/en/{{ policy_docs[model_name] }}), or browse the [full documentation](https://huggingface.co/docs/lerobot/index).
{% else %}See the [full LeRobot documentation](https://huggingface.co/docs/lerobot/index).
{% endif %}
---
## Model Details
- **License:** {{ license | default("\[More Information Needed]", true) }}
{% if base_model %}- **Fine-tuned from:** [{{ base_model }}](https://huggingface.co/{{ base_model }})
{% endif %}{% if robot_type %}- **Robot type:** `{{ robot_type }}`
{% endif %}{% if cameras %}- **Cameras:** {% for camera in cameras %}`{{ camera }}`{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
{% if input_features or output_features %}
## Inputs & Outputs
The policy consumes these observation features and produces these action features.
{% if input_features %}
**Inputs**
| Feature | Type | Shape |
| --- | --- | --- |
{% for name, feature in input_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
{% endfor %}{% endif %}{% if output_features %}
**Outputs**
| Feature | Type | Shape |
| --- | --- | --- |
{% for name, feature in output_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
{% endfor %}{% endif %}{% endif %}
{% if dataset %}
## Training Dataset
- **Repository:** [{{ dataset.repo_id }}](https://huggingface.co/datasets/{{ dataset.repo_id }})
- **Episodes:** {{ dataset.episodes }}
- **Frames:** {{ dataset.frames }}
- **Frame rate:** {{ dataset.fps }} FPS
{% if dataset.tasks %}- **Task(s):** {% for task in dataset.tasks %}"{{ task }}"{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ dataset.repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
{% if training %}
## Training Configuration
| Setting | Value |
| --- | --- |
| Training steps | {{ training.steps }} |
| Batch size | {{ training.batch_size }} |
{% if training.optimizer %}| Optimizer | {{ training.optimizer }} |
{% endif %}{% if training.lr %}| Learning rate | {{ training.lr }} |
{% endif %}{% if training.seed is not none %}| Seed | {{ training.seed }} |
{% endif %}| LeRobot version | {{ training.lerobot_version }} |
{% endif %}
---
## How to Get Started with the Model
New to LeRobot? These guides cover the full workflow:
- **[Install LeRobot](https://huggingface.co/docs/lerobot/main/en/installation)** — set up the `lerobot` package.
- **[Hardware setup](https://huggingface.co/docs/lerobot/main/en/hardware_guide)** — assemble, wire, and calibrate your robot and cameras.
- **[Record data & train a policy](https://huggingface.co/docs/lerobot/en/il_robots)** — the end-to-end imitation-learning walkthrough.
- **[CLI cheat-sheet](https://huggingface.co/docs/lerobot/main/en/cheat-sheet)** — quick reference for the `lerobot-*` commands.
The short version to run and train this policy:
### Run the policy on your robot
```bash
lerobot-rollout \
--strategy.type=base \
--robot.type={{ robot_type | default("<your_robot_type>", true) }} \
--robot.port=<your_robot_port> \
--robot.cameras="{ <camera_1>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}, <camera_2>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}}" \
--policy.path={{ policy_repo_id | default("<hf_user>/<policy_repo_id>", true) }} \
--task="{% if dataset and dataset.tasks %}{{ dataset.tasks[0] }}{% else %}<your_task_description>{% endif %}" \
--duration=60
```
Replace the remaining `<...>` placeholders with your own values: `--robot.port` and the camera names/indices are specific to your machine, and the camera names must match the observation keys this policy was trained on.
When `--strategy.type=base` is used the script doesn't record the episodes. Skipping duration will make the policy run indefinitely. For more information look at [rollout documentation](https://huggingface.co/docs/lerobot/main/en/inference).
{% if base_model %}### Train your own policy
This policy type is usually fine-tuned from the pretrained base model [{{ base_model }}](https://huggingface.co/{{ base_model }}):
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.path={{ base_model }} \
--output_dir=outputs/train/<policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<policy_repo_id> \
--wandb.enable=true
```
{% else %}### Train your own policy
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.type={{ model_name }} \
--output_dir=outputs/train/<policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<policy_repo_id> \
--wandb.enable=true
```
{% endif %}
_Writes checkpoints to `outputs/train/<policy_repo_id>/checkpoints/`._
---
## Evaluation
<!-- Report real-robot results here: run the policy several times per task and count the
successes. Delete the "No evaluation results" line and fill in this table instead:
| Task | Trials | Successes | Success rate |
| ---- | ------ | --------- | ------------ |
| pick the lego brick | 10 | 8 | 80% |
Also worth noting: anything that affects difficulty (new object positions, lighting,
distractors, a different robot of the same type, ...).
-->
_No evaluation results have been provided for this policy yet._
---
## Citation
If you use this policy, please cite the method linked in the description above, along with LeRobot:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
}
```
@@ -13,6 +13,8 @@
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
{% elif model_name == "sarm" %}
A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
{% elif model_name == "robometer" %}
ROBOMETER is a general-purpose video-language robotic reward model built on a fine-tuned Qwen3-VL-4B backbone with progress, preference, and success heads. Given a trajectory video and a task description, it predicts dense, frame-level task progress in [0, 1] and frame-level success probabilities for downstream robot learning, including offline RL, online RL, data filtering and retrieval, and automated failure detection.
{% elif model_name == "topreward" %}
TOPReward is a **zero-shot** reward model that extracts token log-probabilities from an off-the-shelf vision-language model (default Qwen3-VL) as a reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood of the instruction being true, with no fine-tuning required.
{% else %}

Some files were not shown because too many files have changed in this diff Show More