Compare commits

..

35 Commits

Author SHA1 Message Date
Gangwei XU 132ea975f0 fix(lingbot-va): align RoboTwin evaluation (#3784)
Thank you for the RoboTwin fix, and alignment!
2026-06-15 09:52:48 +02:00
Pepijn 961e0d9bcd docs(lingbot_va): condense processor normalization comments
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 12:04:48 +02:00
Pepijn 6496728025 docs(lingbot_va): point checkpoint paths at the lerobot org
The LeRobot-format checkpoints moved from pepijn223/* to lerobot/* (libero_long,
robotwin, base). Update the eval/train --policy.path examples accordingly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:58:31 +02:00
Pepijn 3b37bd0ca6 refactor(lingbot_va): use built-in UnnormalizerProcessorStep for actions
Replace the bespoke LingBotVAActionUnnormalizeStep with the standard
UnnormalizerProcessorStep in QUANTILES mode, which computes the identical
(action + 1) / 2 * (q99 - q01) + q01 mapping. The per-channel q01/q99 are stored
as the step's saved state (a safetensors file) and restored on load; a fresh build
has no action stats so the step is an identity passthrough.

The 3 Hub checkpoints (lerobot/lingbot_va_{libero_long,robotwin,base}) have been
re-uploaded with the new post-processor (policy_postprocessor.json +
*_unnormalizer_processor.safetensors); reloading from the Hub round-trips q01/q99.

- processor_lingbot_va.py: drop the custom step + registry; build the post-processor
  with UnnormalizerProcessorStep (explicit ACTION->QUANTILES norm_map so the
  preprocessor / training path is unchanged).
- tests: assert the built-in step is used, identity-when-no-stats, correct quantile
  unnormalization, and a save_pretrained/from_pretrained stats round-trip.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:57:31 +02:00
Pepijn 8e692e365c docs(lingbot_va): trim provenance comments; default wan path to base repo
- configuration_lingbot_va.py: drop the "──" decorations and the
  "(from transformer/config.json)" note; default wan_pretrained_path to
  robbyant/lingbot-va-base (has the frozen vae/text_encoder/tokenizer subfolders).
- modeling_lingbot_va.py: remove the vendored-code banner and the
  "(upstream wan_va/...)" section-header provenance/dash decorations; condense the
  transformer-dtype comment to one line.

No code changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:47:45 +02:00
Pepijn f617b2c2bf docs(lingbot_va): trim verbose comments
- configuration_lingbot_va.py: condense multi-line field comments to one-liners
  (keep the ── section headers).
- processor_lingbot_va.py: shorten the action-quantile explanation block.
- modeling_lingbot_va.py: drop the bare "# ----" separator rules, keeping the
  one-line section headers.

No code changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:31:05 +02:00
Pepijn c6a51b9b60 refactor(lingbot_va): drop hardcoded action quantiles; source from checkpoint
The LIBERO/RoboTwin action (un)normalization quantiles were hardcoded as module
constants in processor_lingbot_va.py. They are already serialized into each
checkpoint's policy_postprocessor.json (via LingBotVAActionUnnormalizeStep.get_config)
and restored on load by PolicyProcessorPipeline.from_pretrained, so the constants are
dead at eval/load time for the released checkpoints (verified: libero_long/robotwin/base
all carry their quantiles on the Hub).

- Remove LIBERO_ACTION_Q01/Q99, ROBOTWIN_ACTION_Q01/Q99 and _default_action_quantiles.
- make_lingbot_va_pre_post_processors now defaults a fresh (unconverted) build to a
  neutral [-1, 1] mapping (identity rescale); real per-benchmark stats come from the
  saved checkpoint (or postprocessor_overrides), analogous to dataset-stats normalization.
- Update the config doc comment to point at the checkpoint as the source of truth.
- Tests: replace the LIBERO-default assertion with a neutral-default check, and add a
  save_pretrained/from_pretrained round-trip guard for the quantile serialization.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:22:42 +02:00
Pepijn ab49c71c22 Update pyproject.toml
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 11:17:16 +02:00
Pepijn 459efef8a0 Update pyproject.toml
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 11:16:08 +02:00
Pepijn 5568ce7af1 Update lingbot_va.mdx
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 10:47:34 +02:00
Pepijn be0320a420 Merge branch 'main' into worktree-lingbot-va-port 2026-06-08 10:38:35 +02:00
pepijn223 5222f3a4a7 docs(lingbot_va): document EEF action-channel schema + camera order
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 16:11:35 +02:00
pepijn223 f9d12db9cf fix(lingbot_va): CI quality gate + fast-test collection
- Add tests/policies/lingbot_va/__init__.py so the test files don't clash by basename
  with tests/policies/vla_jepa/* under pytest's default import mode (fast-test collection error).
- Fix vendored typos flagged by the typos hook (pach_scale->patch_scale, total_tolen->
  total_token_len, stablized->stabilized) and a mypy union-attr in RoboTwinEnv._read_eef_pose.
- Apply Prettier formatting to docs/source/lingbot_va.mdx.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:46:37 +02:00
pepijn223 71aacda05e feat(lingbot_va): implement training / fine-tuning (flow-matching loss)
- Implement LingBotVAPolicy.forward(): dual-stream flow-matching training loss
  (latent + action, timestep-weighted, action-masked) ported from upstream train.py;
  VAE-encodes camera clips, UMT5-encodes the task, noises both streams, runs the
  block-causal flex-attention training pass (forward_train).
- training_loss_from_streams() core + _build_training_streams() data prep (action
  scatter into the 30-d space, multi-frame VAE encode incl. robotwin_tshape).
- get_optim_params returns only trainable transformer params (LoRA/PEFT friendly);
  VAE/UMT5 stay frozen. Training needs attn_mode='flex'.
- Add a tiny-config single-training-step test (forward->loss->backward->AdamW) and a
  Training/fine-tuning section in the docs.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:38:41 +02:00
pepijn223 e3deff00ad feat(lingbot_va): RoboTwin eef-pose eval, single-file model, Hub checkpoints
Make the LingBot-VA port runnable on both LIBERO and RoboTwin and clean up the
package to LeRobot conventions.

- Consolidate all vendored Wan2.2 model code (transformer, attention, VAE helpers,
  flow-matching scheduler, grid utils, flex-attention) into a single
  modeling_lingbot_va.py; remove the separate wan_*/schedulers modules.
- Move the fixed action (un)normalization quantiles out of the config and into the
  post-processor (LIBERO 7-DoF + RoboTwin 16-d eef); remove the conversion script in
  favour of ready-to-use LeRobot-format checkpoints on the Hub.
- Fixes found via on-sim validation: undo LIBERO's 180-degree image flip
  (image_hflip), encode obs as a multi-frame streaming-VAE clip, reset the streaming
  VAE cache between episodes, run the transformer in config.dtype, lazy-load frozen
  VAE/UMT5 by subfolder with the text encoder on CPU.
- RoboTwin: add an end-effector-pose action mode to RoboTwinEnv (16-d per-arm
  xyz+quat+gripper deltas composed onto the initial eef pose, executed via CuRobo IK)
  and the robotwin_tshape latent layout (full-res head + half-res wrists via a second
  streaming VAE) with the upstream RoboTwin action quantiles + camera mapping.
- Predicted-video saving works for both benchmarks; docs + tests updated.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:20:51 +02:00
Maxime Ellerbach 09808183ca feat(rollout): adding episodic strategy (#3717)
* feat(rollout): adding legacy strategy

* adding legacy to existing tests

* updating docs and docstring

* changing misleading docstring

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding extra guard like dagged with try except finally

* Potential fix for pull request finding

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding reset to initial position

* moving smooth teleop handover to control_utils and adding this behavior to legacy strategy

* reducing duration of the handover

* * renaming to episodic
* changing semantics of the docstring
* fixing leader - follower handover disable torque
* adding optionnal config to disable handover

* wiring the smooth_leader_follower_handover config

* renaming config smooth_leader_to_follower_handover

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-06-06 00:32:38 +02:00
Pepijn 4dfa8cea65 feat(policies): add LingBot-VA autoregressive video-action world model
Port the LingBot-VA policy (Wan2.2 dual-stream video+action world model) into
LeRobot, following the EO-1 / VLA-JEPA conventions. Covers inference, checkpoint
conversion, and predicted-video saving (training is deferred to a follow-up PR).

- Vendored Wan transformer/attention/flex/VAE/scheduler modules (key names preserved
  for near-identity conversion); torch SDPA default, flashattn/flex lazy-guarded.
- LingBotVAConfig (registered "lingbot_va") + processor with fixed-quantile action
  unnormalization; full dual-stream sampling loop with CFG, two flow-matching
  schedulers and KV cache, mapped onto select_action with observed-keyframe feedback.
- convert_lingbot_va_checkpoints.py (libero/robotwin variants): bundles the ~5B
  transformer, lazy-pulls the frozen VAE+UMT5 from the source repo.
- Predicted-video plumbing in lerobot_eval (predicted_frames_callback; opt-in via
  --policy.save_predicted_video) and ConstantWithWarmupSchedulerConfig.
- pyproject: widen diffusers-dep to <0.37, add lingbot_va + imageio-dep extras,
  add lingbot_va and (missing) eo1 to `all`.
- Factory + policies/__init__ wiring, docs page + toctree, and tests.

Note: the LIBERO success-rate correctness gate must be validated on a CUDA GPU
with the converted checkpoint.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-05 16:28:19 +02:00
Maxime Ellerbach 2e9cd87bbd feat(policies): add VLA-JEPA (#3568)
* first commit

* feat(policies): add VLA-JEPA

* feat(policies): add VLA-JEPA

* support vla_jepa

* (feat)policies: add VLA-JEPA

* linting

* adding deps to pyproject.toml

* updating uv lock

* adding guards to avoid needing transformers and diffusers for type checking and basic tests

* fixing action and state dim

* fix warnings with qwen processor kwargs

* fixing wm_loss not propagating

* adjusting obs steps, tublets size to match original implementation

* some more fixes to be closer to the original implem

* adding more tests to ensure good coverage

* align VLA-JEPA architecture with original checkpoint

- Remove stale `action_num_heads` / `action_attention_head_dim` config fields;
  DiT head dimensions are now always derived from the preset (DiT-B/L/test).
- Add `num_target_vision_tokens` and `action_max_seq_len` config fields required
  by the action head's future-token embedding and positional embedding tables.
- Fix default `qwen_model_name` to 2B (matches all released checkpoints).
- Rename `ActionEncoder` attrs w1/w2/w3 → layer1/layer2/layer3 to match
  checkpoint key names; replace `nn.Sequential` decoder/state-encoder with
  `_MLP2` (layer1/layer2 naming).
- Fix `VLAJEPAActionHead` to size ActionEncoder and StateEncoder at `inner_dim`
  (DiT input width) rather than `action_hidden_size` (DiT output width).
- Rename `DiT.blocks` → `transformer_blocks` and `attn` → `attn1` to match
  checkpoint; add alternating cross/self attention (even blocks cross-attend to
  Qwen context, odd blocks self-attend).
- Add `DiT-test` preset for unit tests.
- Rewrite `ActionConditionedVideoPredictor` with explicit ViT-style blocks
  (`_PredictorBlock` with fused qkv) to match checkpoint structure; rename
  `encoder`/`norm`/`proj` → `predictor_blocks`/`predictor_norm`/`predictor_proj`.

* propagate action_is_pad masking through VLA-JEPA policy pipeline

Pass the `action_is_pad` tensor from the batch through to the action head
so padded timesteps are excluded from the flow-matching loss.

* update VLA-JEPA tests for arch changes and action_is_pad

- Switch conftest to use `action_model_type="DiT-test"` now that
  `action_num_heads` / `action_attention_head_dim` have been removed.
- Add action_head tests covering fully-padded loss (zero) and equivalence
  of action_is_pad=None vs all-zeros mask.
- Remove obsolete `test_native_to_lerobot_wm_only` test.

* add VLA-JEPA documentation

Covers architecture overview, pretrained checkpoints, config reference,
training/eval commands for LIBERO-10, and guidance on fine-tuning for
single-camera datasets.

* add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated)

* make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model
- added tests for weight loading

* trying out to re-init the action head to avoid pretraining dimension mismatch

* allow different state dim and action dim

* removing missleading future_action_window_size to just use chunk_size

* lots of changes to make existing weights work, need to massively refactor the pre and post processing

* refactoring into using pre and post processor

* pre-commit cleanup

* fixing doc defaults args

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adressing dtype zeros issue

* adding guard for diffusers

* fixing training and exal examples

* trying to close success rate gap

* fix qwen norm layer output libero eval is now as expected

* adding instructions for different embodiement + fixing some tests

* smol fix to avoid having default CPU device when training

* fixing misconception about multiview / singleview handling

* removing conversion script

* adding licences

* adding .mdx docs and shortening polivy_vla_jepa_README.md

* removing useless pre-processor

* cleanup

* removing swish in favor of silu

* adding configuration gripper index and threshold

* fixing simlink

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: ginwind <ginwind@mail.ustc.edu.cn>
2026-06-04 19:22:51 +02:00
Jaimin d1b1c5c8cf docs: fix broken dataset script paths (datasets/v30 -> scripts) (#3695)
The docs pointed at src/lerobot/datasets/v30/, which does not exist.
Both scripts actually live in src/lerobot/scripts/:

- convert_dataset_v21_to_v30.py
- augment_dataset_quantile_stats.py

Updated the four references (one python -m module path and three
file-path invocations) to the correct location, matching each
script's own usage docstring.
2026-06-03 14:48:19 +02:00
Nikodem Bartnik 741c2d0a39 Docs/add lelab (#3707)
* first text draft (no images)

* simplified docs

* fix formatting

* add youtube video

* add a tip about compatibility

* fix broken link
2026-06-03 14:22:05 +02:00
Haoming Song 19fe315971 fix(train): enable relative action overrides for pretrained processors (#3711)
* fix(train): enable relative action overrides for pretrained processors
Keep pretrained processor pipelines when use_relative_actions is enabled and
apply relative/absolute action processor settings through overrides. Rename the
relative action processor registry key to relative_actions_processor.

* fix(config): reject rename_map without pretrained checkpoint

Fail fast when rename_map is set during fresh initialization, since fresh
configs derive feature names from the current dataset and no rename is applied.

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-03 11:46:35 +02:00
Khalil Meftah 906b585826 fix(datasets): default private to None in push_to_hub to respect Hub org visibility settings (#3713) 2026-06-02 19:25:13 +02:00
Khalil Meftah b8ad81bf39 feat(rewards): add ROBOMETER reward model (#3627)
* feat/add ROBOMETER reward model

* feat(rewards): add Robometer offline progress labeling script

* fix(rewards/robometer): add missing input keys mm_token_type_ids

* chore(rewards/robometer): default to lerobot/Robometer-4b model

* doc(rewards/robometer): update citation and original github link

* feat(rewards/robometer): add image key argument to compute Robometer progress
2026-05-29 21:45:39 +02:00
Haoquan Fang 24017e960c Add MolmoAct2 policy (#3604)
* add molmoact2 policy

* add apache headers to molmoact2 files

* simplify molmoact2 package imports

* align molmoact2 feature validation with eo pattern

* remove molmoact2 processor override from factory

* guard molmoact2 transformers imports

* guard molmoact2 processor transformers import

* add scipy dependency to molmoact2 extra

* use a single molmoact2 action queue

* move molmoact2 config logic into config

* fix molmoact2 hf image key resolution

* load molmoact2 without remote code

* lazy import molmoact2 scipy

* format molmoact2 files

* skip molmoact2 tests without optional deps

* fix molmoact2 pre-commit checks

* validate molmoact2 gripper range
2026-05-27 18:58:37 +02:00
Khalil Meftah e86f5af5bf feat(rewards): add TOPReward reward model (#3629)
* feat(rewards): add TOPReward reward model

* refactor(rewards): clean up TOPReward processor/model

* fix(rewards/topreward): add missing input keys mm_token_type_ids

* fix(rewards/topreward): fix pyproject extra typo and simplify processor (#3653)

Add lerobot[topreward] extra to all in
pyproject.toml, drop the redundant labels arg in scoring, and
collapse the dead-branch shape check in the encoder processor.

* optmize topreward input processing (#3660)

---------

Co-authored-by: Cole <91766445+jcoleharrison@users.noreply.github.com>
Co-authored-by: Haoming Song <haomingsong24@gmail.com>
2026-05-27 14:24:31 +02:00
Haoming Song 5c98e80430 fix(gr00t): fix Eagle25VL model and processor crash in transformers>=5.4.0, <5.6.0 (#3652)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-26 14:04:22 +02:00
Reece O'Mahoney f65f3f7a4a Fix policy.path in YAML configs (PR #3145 followup) (#3597)
PR #3145 added YAML support for policy.path but left two bugs:

1. extract_path_fields_from_config only deleted config_data[field] when
   no sibling overrides existed. With siblings, the dict stayed in place
   and draccus crashed decoding it as PreTrainedConfig (no 'type' key).
   Sibling overrides go into _config_yaml_overrides and are applied later
   by from_pretrained(), so the field can always be removed.

2. wrap() updated config_path_cli to the cleaned temp file path but
   never propagated it to the draccus.parse fallback branch. cli_args
   still contained --config_path=<original>, so draccus read the
   original YAML with path: still present.

Tests passed because they (a) called extract_path_fields_from_config
directly and (b) included type: alongside path: in the YAML, sidestepping
both bugs.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-26 14:01:19 +02:00
Pepijn 8194897994 fix(deps): cap placo below 0.9.16 and harden kinematics import (#3647)
* fix(deps): cap placo below 0.9.16 and harden kinematics import

placo 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable
on Ubuntu 24.04 (noble ships urdfdom 3.x). Importing placo on that base
crashes with:

  ImportError: liburdfdom_sensor.so.4.0: cannot open shared object file

This broke nightly Latest Deps tests (CPU and GPU) when the lockfile
upgrade picked placo 0.9.16, since lerobot.model.kinematics
unconditionally imports placo when _placo_available is true, and that
check (importlib.util.find_spec) cannot detect dlopen failures of
transitive shared libraries — so unrelated subsystems (RL actor,
gym_manipulator) became unimportable.

Two changes:

1. Pin placo to <0.9.16 in pyproject.toml + regenerate uv.lock
   (0.9.16 → 0.9.15). Short-term unblock for nightly CI until system
   urdfdom 4.x is broadly available.

2. Harden the import guard in src/lerobot/model/kinematics.py:
   wrap 'import placo' in try/except ImportError so a missing
   transitive .so no longer crashes module import. RobotKinematics
   instantiation now raises an informative ImportError citing the
   underlying dlopen failure via _raise_if_placo_unusable().

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(kinematics): hoist _placo_runtime_error to module scope for mypy

Mypy walks the TYPE_CHECKING branch in which the runtime else-block is
not executed, so _placo_runtime_error was only defined at runtime and
mypy reported 'Name "_placo_runtime_error" is not defined' on the
three references inside _raise_if_placo_unusable. Declare the symbol
unconditionally at module scope with a default of None; the runtime
import-failure branch still assigns to it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* style(kinematics): drop verbose comments around placo import guard

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 12:03:07 +02:00
Haoming Song 9f437d86b6 fix(groot): align GR00TN15Config with transformers config dataclasses (#3606)
* fix(gr00t): fix gr00t config dataclass init TypeError

* fix(groot): guard strict config decorator without transformers for passing CI

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-05-22 10:31:04 +02:00
Haoming Song b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchamrk for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00
Nikodem Bartnik c0a2e9814d fix examples (#3623)
- Fixed broken API examples in Lerobot Imitation Learning Documentation
- Teleoperation with cameras improved by adding a fixed frequency in the loop (without it the cameras feed gets very slow)
- Wrapped record example script in main() to avoid problems on Mac
- Previously teleoperation example was using SO-ARM and teleoperation with cameras was using Koch. I changed it to use SO-ARM in all of the examples.
- Added section on how to train with HF Jobs - CLI and Python examples
- Replaced lerobot-record with lerobot-rollout in policies examples
2026-05-21 22:14:07 +02:00
Khalil Meftah bac4f61eae refactor: support custom progress parquet overlays (#3640) 2026-05-21 14:32:10 +02:00
Virgileboat f4b834844e Feat/clean can bus (#3526)
* change timeout  for handshake

* enforce last state read when querry

* change import order

* fix(motors): flush stale robstride RX and harden feedback drain

* robstride: remove redundant timeout and max_messages casts

* bugfix + %-style

* update exception catch
2026-05-21 11:44:04 +02:00
Roham Z. Nobari dfdc48a7f1 fix(datasets): bound VideoDecoderCache to prevent OOM on large datasets (#3614)
VideoDecoderCache used an unbounded dict keyed on absolute path, with no
eviction in the standard LeRobotDataset path. With shuffled iteration over
datasets that have many distinct mp4 files, every DataLoader worker
accumulated one cached (VideoDecoder, fsspec file handle) pair per distinct
path it had ever touched. Per-entry cost is ~3-5 MB of host RAM plus one
open FD; at ~8 k entries this is roughly 30 GB per worker.

This was hit in the wild during a SmolVLA training run on a 4,195-episode
SO-101 dataset (8,390 mp4s, two cameras per episode). dmesg showed
anon-rss climbing to 34.9 GB on a single pt_data_worker before the OOM
killer fired ~30 min into training; with --num_workers=8 the per-worker
peak halved to 17.9 GB, which is the expected inverse-scaling signature
when the leak is per-decode and the workload is split across workers. The
working workaround on the affected platform was --dataset.video_backend=pyav,
because the pyav path opens/closes per call and never touches this cache.

Switch the backing store to an OrderedDict and evict LRU entries when the
cap is reached, closing the evicted file handle inside the lock so we do
not leak FDs either. Default cap is DEFAULT_DECODER_CACHE_SIZE = 100,
overridable via LEROBOT_VIDEO_DECODER_CACHE_SIZE or by passing max_size=
to the constructor; max_size=None restores the legacy unbounded behaviour
for callers that need it.

Validation on the original failing workload (decode_video_frames_torchcodec
called over real mp4s from the affected SO-101 dataset):

  unbounded:    300 files  ->  +1087 MB host RSS,  cache=300, still climbing
  cap=50:       500 files  ->   +266 MB host RSS,  cache=50,  stable
  cap=50:      2000 calls  ->   +312 MB host RSS,  cache=50,  stable
  cap=100:     1000 calls  ->   +470 MB host RSS,  cache=100, stable

Three independent seeded runs at cap=50 agreed to within 1% (263 / 266 /
265 MB delta), and the 2000-call multi-pass run shows RSS plateaus after
the cap is reached instead of drifting.

Tests in tests/datasets/test_video_decoder_cache.py cover:
default-is-bounded, size cap, LRU ordering, FD close on eviction, FD close
on clear(), cache-hit invariance, max_size=None fallback, and env-var
override. No regressions in test_video_encoding.py, test_streaming.py, or
test_dataset_reader.py (73 prior tests still pass alongside the 8 new ones).
2026-05-19 16:54:25 +02:00
四七 6a8878a639 fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save

* test(datasets): cover shape=(1,) int/bool and finalize

Co-authored-by: Copilot <copilot@github.com>
2026-05-19 16:53:19 +02:00
117 changed files with 25671 additions and 1033 deletions
+4
View File
@@ -22,6 +22,10 @@ outputs
rl
media
# Local virtualenvs (the image provides its own)
.venv
venv
# Logging
logs
+10
View File
@@ -9,6 +9,8 @@
- sections:
- local: il_robots
title: Imitation Learning for Robots
- local: lelab
title: LeLab - Lerobot GUI
- local: bring_your_own_policies
title: Adding a Policy
- local: integrate_hardware
@@ -59,8 +61,14 @@
title: π₀-FAST (Pi0Fast)
- local: pi05
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: lingbot_va
title: LingBot-VA
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
@@ -73,6 +81,8 @@
- sections:
- local: sarm
title: SARM
- local: robometer
title: ROBOMETER
- local: topreward
title: TOPReward
title: "Reward Models"
+6 -10
View File
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash
lerobot-record \
--robot.type=so100_follower \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
--task="Your task description" \ # can be skipped for ACT
--duration=60
```
+5 -5
View File
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
```bash
lerobot-record \
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
@@ -119,14 +121,12 @@ lerobot-record \
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
--duration=600
```
## License
+210 -108
View File
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541",
id="my_red_robot_arm",
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
)
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command">
```bash
lerobot-teleoperate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--teleop.type=koch_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=my_awesome_leader_arm \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_follower_arm \
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_leader_arm \
--display_data=true
```
</hfoption>
@@ -122,34 +122,48 @@ lerobot-teleoperate \
<!-- prettier-ignore-start -->
```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
}
robot_config = KochFollowerConfig(
port="/dev/tty.usbmodem585A0076841",
id="my_red_robot_arm",
cameras=camera_config
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
teleop_config = KochLeaderConfig(
port="/dev/tty.usbmodem58760431551",
id="my_blue_leader_arm",
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
robot = KochFollower(robot_config)
teleop_device = KochLeader(teleop_config)
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot.connect()
teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True:
start_time = time.perf_counter()
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
```
<!-- prettier-ignore-end -->
@@ -202,10 +216,11 @@ lerobot-record \
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
# Create robot configuration
robot_config = SO100FollowerConfig(
id="my_awesome_follower_arm",
cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
},
port="/dev/tty.usbmodem58760434471",
)
teleop_config = SO100LeaderConfig(
id="my_awesome_leader_arm",
port="/dev/tty.usbmodem585A0077581",
)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop = SO100Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
def main():
# Create robot configuration
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
}
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem5AB90689011",
id="my_leader_arm",
)
# Initialize the robot and teleoperator
robot = SO101Follower(robot_config)
teleop = SO101Leader(teleop_config)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
teleop.connect()
# Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop(
robot=robot,
events=events,
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
episode_idx += 1
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
dataset.save_episode()
episode_idx += 1
# finalize dataset
log_say("Finalizing dataset...")
dataset.finalize()
# Clean up
log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
```
<!-- prettier-ignore-end -->
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot = SO100Follower(robot_config)
robot.connect()
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
To run the training use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
For longer training sessions increase the timeout.
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:
@@ -546,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
- `episodic`: Episode-oriented policy recording with reset phases between episodes
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+38
View File
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
### Episodic (`--strategy.type=episodic`)
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
```bash
lerobot-rollout \
--strategy.type=episodic \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so100_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/my_eval_data \
--dataset.num_episodes=20 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--dataset.single_task="Pick up the red cube"
```
Teleop is optional — if omitted the robot holds its position during the reset phase.
**Keyboard controls:**
| Key | Action |
| ----------- | -------------------------------- |
| `→` (right) | End the current episode early |
| `←` (left) | Discard episode and re-record it |
| `ESC` | Stop the recording session |
| Flag | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------- |
| `--dataset.num_episodes` | Number of episodes to record |
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
---
## Inference Backends
+29
View File
@@ -0,0 +1,29 @@
# LeLab - LeRobot Guide
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
> [!TIP]
> For now LeLab is compatible only with SO-ARM101
<Youtube id="VqyKUuW9V1g" />
### Installation
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
```
uv tool install git+https://github.com/huggingface/leLab.git && lelab
```
After install, run `lelab` from your terminal anytime to start the app.
### Features
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
+1 -1
View File
@@ -275,7 +275,7 @@ A converter aggregates perepisode files into larger shards and writes episode
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
# Convert an existing v2.1 dataset hosted on the Hub:
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
```
**What it does**
+187
View File
@@ -0,0 +1,187 @@
# LingBot-VA
LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
integration wires LingBot-VA into the standard training, evaluation and processor
interfaces.
## Model Overview
LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
text conditioning.
| Component | Class | Role |
| ------------------------ | ----------------------- | ----------------------------------------------------------- |
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer. |
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
### What the LeRobot Integration Covers
- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
- Autoregressive dual-stream inference behind the standard `select_action` interface
(single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
## Installation
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the LingBot-VA extra:
```bash
pip install -e ".[lingbot_va]"
```
## Checkpoints
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
| Variant | LeRobot checkpoint |
| ---------------------- | -------------------------------- |
| LIBERO-Long post-train | `lerobot/lingbot_va_libero_long` |
| RoboTwin post-train | `lerobot/lingbot_va_robotwin` |
| Pretrained base | `lerobot/lingbot_va_base` |
Only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are pulled from
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
transformer + VAE fit on a single 2432 GB GPU.
## Evaluation (LIBERO)
```bash
lerobot-eval \
--policy.path=lerobot/lingbot_va_libero_long \
--policy.device=cuda \
--env.type=libero --env.task=libero_10 \
--env.observation_height=128 --env.observation_width=128 \
--eval.n_episodes=50 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_libero
```
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.
## Evaluation (RoboTwin)
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack. You can use the benchmark Docker image
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
executed via CuRobo IK.
```bash
lerobot-eval \
--policy.path=lerobot/lingbot_va_robotwin \
--policy.device=cuda \
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
--eval.n_episodes=10 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_robotwin
```
### Saving predicted (imagined) videos
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.
## Training / fine-tuning
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
with a linear-warmup-then-constant schedule (matching upstream).
Requirements:
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
- The full 5B DiT does not fit a single 2432 GB GPU under AdamW; fine-tune with **LoRA**
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
```bash
lerobot-train \
--policy.path=lerobot/lingbot_va_libero_long --policy.attn_mode=flex \
--policy.use_peft=true \
--dataset.repo_id=<your LeRobot-format dataset> \
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
```
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
## Data format (action channels & camera order)
LingBot-VA is an **end-effector (Cartesian) pose** policy, it predicts EEF poses + gripper, not
joint positions. Actions live in a fixed multi-embodiment **30-dim** layout; map your robot's
action dimensions into these channels and pad the rest with `0` (`used_action_channel_ids` selects
the channels a given checkpoint actually uses):
| channels | meaning |
| -------- | ----------------------------------------------------- |
| 06 | Left-arm end-effector pose |
| 713 | Right-arm end-effector pose |
| 1420 | Left-arm joints (unused by the released checkpoints) |
| 2127 | Right-arm joints (unused by the released checkpoints) |
| 28 | Left gripper |
| 29 | Right gripper |
- **LIBERO** uses channels `06`: a 6-DoF EEF delta (xyz + rotation) + gripper (single arm).
- **RoboTwin** uses channels `[06, 28, 713, 29]`: left EEF (xyz + quaternion) + left gripper +
right EEF + right gripper (16 dims). The env converts these poses to joint trajectories via
CuRobo IK — joints are never predicted.
Joint-space datasets (or a different EEF convention) must be remapped into this schema before
fine-tuning these checkpoints.
**Camera order is fixed and order-sensitive**, per-camera latents are concatenated spatially in
`obs_cam_keys` order, so the physical camera→slot mapping must match training:
| benchmark | `obs_cam_keys` (in order) | `camera_layout` |
| --------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
| LIBERO | `observation.images.image` (agentview / 3rd-person), `observation.images.image2` (eye-in-hand wrist) | `width_concat` (latents concatenated on width) |
| RoboTwin | `observation.images.head_camera`, `observation.images.left_camera`, `observation.images.right_camera` | `robotwin_tshape` (full-res head below, two half-res wrists on top) |
The first camera is the exterior/head view and the rest are wrist views.
## Inference Hyperparameters (LIBERO)
| Key | Value |
| -------------------------------------- | --------------------------------------------------------------------------------- |
| height × width | 128 × 128 |
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
| action channels used | 06 (7-DoF arm + gripper) |
| action_per_frame / frame_chunk_size | 4 / 4 |
| attn_window | 30 |
| video / action denoising steps | 20 / 50 |
| guidance_scale / action_guidance_scale | 5 / 1 |
| snr_shift / action_snr_shift | 5.0 / 0.05 |
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
## Notes
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
`flashattn` and `flex` backends are optional; `flex` is only needed for training.
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
roughly 1824 GB of VRAM.
## License
LingBot-VA is released under Apache-2.0. See the
[upstream repository](https://github.com/Robbyant/lingbot-va).
+433
View File
@@ -0,0 +1,433 @@
# MolmoAct2 Policy
MolmoAct2 is the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
training, evaluation, checkpointing, and dataset interfaces for easier use with
LeRobot datasets.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## Installation Requirements
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
`gradient_checkpointing=true` and include the forward pass, backward pass,
gradient clipping, optimizer step, and optimizer state allocation. Values are
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
dataloader workers, CUDA context, and fragmentation.
Multi-GPU training through `accelerate` increases throughput and global batch
size, but this LeRobot port does not currently expose the original MolmoAct2
`fsdp_devices` model-parallel training path. The current training script has
not been tested for multi-node training.
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
The repo has been tested with Ubuntu 22.04.
## Usage
To use MolmoAct2 in a LeRobot training config, set:
```python
policy.type=molmoact2
```
## Training
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
the same LeRobot training loop, dataset transforms, checkpoint saving, and
logging. The difference is only how the initial policy weights and processor
state are loaded.
### Training With Original MolmoAct2 Weight
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
the original HF model files, then build its own policy processor from the
dataset metadata and the policy options below.
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
augmentation, and LeRobot's checkpointing/logging path.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.setup_type="single franka robotic arm in libero" \
--policy.control_mode="delta end-effector pose" \
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--policy.freeze_embedding=true \
--policy.normalize_gripper=false \
--policy.enable_knowledge_insulation=false \
--policy.push_to_hub=false \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Training With LeRobot MolmoAct2 Weight
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
restores the saved LeRobot policy config, model weights, processor, and
normalization statistics. You can still override training-time options such as
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
```bash
accelerate launch \
--num_processes=8 \
--mixed_precision=bf16 \
-m lerobot.scripts.lerobot_train \
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--policy.path=/path/to/pretrained_model \
--policy.device=cuda \
--policy.action_mode=both \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.model_dtype=bfloat16 \
--policy.num_flow_timesteps=8 \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--wandb.entity=<wandb_entity> \
--wandb.project=<wandb_project> \
--job_name=<job_name> \
--output_dir=outputs/<job_name> \
--steps=10000 \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
### Common Practices
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
or a real-world dataset with less than 200 demonstrations, a global batch size of
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
cases, we intentionally keep the action expert fully trainable, which we found
to be crucial for model performance. For larger fine-tuning datasets, larger
global batch sizes and full fine-tuning are usually preferred.
### Common Policy Options
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
Use this for released MolmoAct2 weights.
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
created by LeRobot training.
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
`both`. `both` trains the flow-matching action expert and the discrete
action-token loss.
- `policy.train_action_expert_only`: trains only parameters whose names contain
`action_expert`. It requires `policy.action_mode=continuous`.
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
expert linear layers. When `policy.enable_lora_action_expert=false`, the
action expert base weights remain fully trainable while the VLM is trained
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
action expert is also adapter-tuned instead of fully fine-tuned.
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
context K/V states before the action loss. The default is `false`.
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
`10`. This LeRobot port overrides the loaded checkpoint's
`max_action_horizon` with this value.
- `policy.n_action_steps`: number of actions consumed from each predicted
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
- `policy.setup_type`: text inserted into the prompt to describe the robot and
scene, e.g. `single franka robotic arm in libero`. More examples are listed
in the `metadata_by_tag` entries of
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
- `policy.control_mode`: text inserted into the prompt to describe the action
space, e.g. `delta end-effector pose` or `absolute joint pose`.
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
processor.
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
example during training. We use `8` for fine-tuning.
- `policy.num_inference_steps`: optional override for continuous action
generation steps at inference time.
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
to reduce activation memory.
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
- `policy.normalize_gripper`: controls whether gripper dimensions are included
in state/action quantile normalization. The default is `false`.
- `policy.normalize_language`: normalizes task strings before prompt
construction. The default is `true`.
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
Released checkpoints use `policy.expected_max_action_dim=32`.
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
infer it from images, state dimension, action dimension, action horizon, and
discrete-action mode.
### Learning Rates
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
fine-tuning experiments.
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
`policy.optimizer_vit_lr=5e-6` for the vision tower,
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
`policy.enable_lora_action_expert=false`, so the action expert is still fully
fine-tuned with `policy.optimizer_action_expert_lr`. If
`policy.enable_lora_action_expert=true`, the action expert is trained through
LoRA adapters instead.
- Action-expert-only fine-tuning trains only the action expert and uses
`policy.optimizer_action_expert_lr=5e-5`.
You can override the full fine-tuning and action-expert learning rates with
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
### Dataset Quantile Statistics
MolmoAct2 defaults to quantile normalization for state and action features. If
your dataset has not been converted with quantile statistics, you can add them
with:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
--repo-id=your_dataset
```
Alternatively, train MolmoAct2 with mean/std normalization:
```bash
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
```
## Evaluation
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
fixed and use `policy.per_episode_seed=true`.
**Important:** We found that `num_steps_wait=10` does not reliably let the
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
results reported here use `num_steps_wait=50`.
### Evaluation With LeRobot MolmoAct2 Weight
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
normalization statistics are restored together with the model.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
--policy.inference_action_mode=continuous \
--policy.model_dtype=bfloat16 \
--policy.use_amp=true \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
### Evaluation With Original MolmoAct2 Weight
You can evaluate a released Hugging Face checkpoint directly without first
converting it to a LeRobot checkpoint. In this case, set
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
normalization statistics, action horizon, prompt metadata, and image-key order
from the checkpoint's `norm_stats.json`.
To fully replicate the MolmoAct2 paper results with released Hugging Face
checkpoints, we recommend using the v0.5.1-pinned
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
branch. That branch matches the original evaluation settings used for the
reported numbers.
```bash
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
lerobot-eval \
--policy.type=molmoact2 \
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
--policy.norm_tag=libero \
--policy.inference_action_mode=continuous \
--policy.model_dtype=float32 \
--policy.use_amp=false \
--policy.enable_inference_cuda_graph=true \
--policy.device=cuda \
--policy.per_episode_seed=true \
--policy.eval_seed=1000 \
--env.type=libero \
--env.task=libero_goal \
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=1000
```
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
full LIBERO suite. The same command works for other released MolmoAct2
checkpoints as long as the requested `policy.norm_tag` exists in that
checkpoint's `norm_stats.json`.
### Common Evaluation Options
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
flow-matching inference or `discrete` for action-token inference. It must be
compatible with the training-time `policy.action_mode` saved in the
checkpoint.
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
saved by LeRobot.
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
image-key order, and action horizon from the original checkpoint's
`norm_stats.json`. It is required for direct original-HF checkpoint
evaluation.
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
- `policy.use_amp`: runs the policy forward under autocast during eval. For
`model_dtype=bfloat16`, keep this enabled.
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
graph path for faster repeated continuous-action rollout.
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
action generation deterministic per episode for replication.
- `env.task`: comma-separated LIBERO suites or a single suite. Use
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
by the policy processor.
## Performance Results
### LIBERO Benchmark Results
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
compare and test its LeRobot implementation, we fine-tuned
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
results.
The LeRobot fine-tuned checkpoint reported here is available at
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
and was trained on
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
| -------------- | ---------------------: | -----------------: |
| LIBERO Spatial | 98.4% | 97.8% |
| LIBERO Object | 100.0% | 100.0% |
| LIBERO Goal | 98.0% | 97.8% |
| LIBERO 10 | 96.6% | 93.2% |
| Average | 98.25% | 97.20% |
These results demonstrate MolmoAct2's strong performance across diverse robotic
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
evaluation section.
## Differences From the Original Implementation
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
differences from the original training repository are:
- The original paper training stack loads the model in fp32 and trains under
mixed precision. This LeRobot port usually loads the checkpoint directly in
`policy.model_dtype=bfloat16` for lower memory use.
- The original repository uses its own FSDP/model-parallel training path. The
LeRobot port uses the standard LeRobot/Accelerate training path and has not
been tested for multi-node training.
- The original repository supports sequence packing. The LeRobot port trains on
one LeRobot sample per item and pads to an inferred fixed sequence budget.
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
dataset transforms, image augmentation, and Weights & Biases logging
conventions.
- The original training path supports mixed action horizons by padding to
`max_action_horizon` and masking padded horizon slots in the action expert
self-attention. This is useful when training across datasets with different
control frequencies. The LeRobot port currently targets single-dataset
fine-tuning, so `policy.chunk_size` overrides the checkpoint
`max_action_horizon` and horizon masking is not implemented yet. Support for
this mixed-horizon path is planned.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
+1 -1
View File
@@ -91,7 +91,7 @@ lerobot-train \
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
--repo-id=your_dataset \
```
+39
View File
@@ -0,0 +1,39 @@
# MolmoAct2
This repository contains the LeRobot policy implementation of
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
training, evaluation, checkpointing, and dataset compatibility.
This implementation currently supports training and evaluation for the regular
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
not included in this LeRobot policy yet and is coming soon.
For the original MolmoAct2 training code used for the experiments reported in
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
## LIBERO Evaluation
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
scene stabilize and can degrade measured success. All LIBERO evaluation results
reported for this LeRobot implementation use `num_steps_wait=50`.
## Citation
```bibtex
@misc{fang2026molmoact2actionreasoningmodels,
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
year={2026},
eprint={2605.02881},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2605.02881},
}
```
## License
This model is licensed under Apache 2.0. It is intended for research and
educational use in accordance with
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
+39
View File
@@ -0,0 +1,39 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+1 -1
View File
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
If you have existing datasets in v2.1 format, use the migration tool:
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
--repo-id your_id/existing_dataset
```
+185
View File
@@ -0,0 +1,185 @@
# ROBOMETER
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
**Project**: [robometer.github.io](https://robometer.github.io/)
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
## Overview
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
- **Success head**: predicts per-frame task success probability.
- **Preference head**: predicts which of two trajectories better completes the task during training.
The paper trains ROBOMETER with a composite objective:
```text
L = L_pref + L_prog + L_succ
```
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
## What the LeRobot Integration Covers
- Standard `reward_model.type=robometer` configuration through LeRobot.
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
- Dense, frame-level progress and success predictions internally.
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
## Installation Requirements
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the ROBOMETER dependencies:
```bash
pip install -e ".[robometer]"
```
If you use `uv` directly from a source checkout:
```bash
uv sync --extra robometer
```
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
## Model Inputs and Outputs
ROBOMETER expects:
- A trajectory video or sequence of frames.
- A natural-language task description.
In LeRobot datasets, the preprocessor reads:
| Config field | Default | Meaning |
| ------------------------- | ------------------------ | ----------------------------------------------------- |
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
## Usage
### Load the Reward Model Directly
```python
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
cfg = RobometerConfig(
pretrained_path="lerobot/Robometer-4B",
device="cuda",
reward_output="progress",
)
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
```
### Encode Frames and Compute a Reward
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
```python
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
# task: str
encoder = RobometerEncoderProcessorStep(
base_model_id=cfg.base_model_id,
use_multi_image=cfg.use_multi_image,
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
max_frames=cfg.max_frames,
)
encoded = encoder.encode_samples([(frames, task)])
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
reward = reward_model.compute_reward(batch)
```
`reward` is a tensor of shape `(batch_size,)`.
### Use the Reward Factory
You can also instantiate ROBOMETER through the reward factory:
```python
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
cfg = make_reward_model_config(
"robometer",
pretrained_path="lerobot/Robometer-4B",
device="cuda",
image_key="observation.images.top",
)
reward_model = make_reward_model(cfg)
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
```
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
## Configuration Notes
### Backbone and Vocabulary
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
```text
<|split_token|>
<|reward_token|>
<|pref_token|>
<|sim_token|>
<|prog_token|>
```
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
### Progress Prediction
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
### Success Prediction
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
## Limitations
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
## References
- [ROBOMETER project](https://robometer.github.io/)
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
- [Original ROBOMETER code](https://github.com/robometer/robometer)
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
## Citation
```bibtex
@inproceedings{liang2026robometer,
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
year={2026},
booktitle={Robotics: Science and Systems 2026},
}
```
## License
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
+8 -8
View File
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing:
```bash
lerobot-record \
lerobot-rollout \
--strategy.type=base \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
# <- RTC optional, use when running on low power hardware \
# --inference.type=rtc \
# --inference.rtc.execution_horizon=10 \
# --inference.rtc.max_guidance_weight=10.0 \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
```
+235
View File
@@ -0,0 +1,235 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
+37 -15
View File
@@ -15,10 +15,12 @@
# limitations under the License.
"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage:
python examples/dataset/create_progress_videos.py \
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path:
"""Download only the metadata and sarm_progress files for a dataset.
def download_episode_metadata(
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
Args:
repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Local cache path for the downloaded snapshot.
"""
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
local_path = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
allow_patterns=["meta/**", progress_file],
ignore_patterns=["*.mp4"],
)
)
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for an episode.
def load_progress_data(
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
Args:
local_path: Dataset cache root.
episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
"""
parquet_path = local_path / "sarm_progress.parquet"
parquet_path = local_path / progress_file
if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found")
logging.warning("%s not found", progress_file)
return None
df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
logging.info(" %s columns: %s", progress_file, list(df.columns))
episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode)
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
return None
episode_df = episode_df.sort_values("frame_index")
@@ -576,6 +585,7 @@ def process_dataset(
camera_key: str | None,
output_dir: Path,
create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output.
@@ -585,6 +595,8 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns:
Path to the final output file, or None on failure.
@@ -592,7 +604,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode)
local_path = download_episode_metadata(repo_id, episode, progress_file)
logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -600,9 +612,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode)
progress_data = load_progress_data(local_path, episode, progress_file)
if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.")
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
return None
logging.info(" Progress frames: %d", len(progress_data))
@@ -627,7 +639,7 @@ def process_dataset(
def main() -> None:
parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
)
parser.add_argument(
"--repo-id",
@@ -658,6 +670,15 @@ def main() -> None:
action="store_true",
help="Also generate a GIF from the MP4 output.",
)
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -670,6 +691,7 @@ def main() -> None:
camera_key=args.camera_key,
output_dir=args.output_dir,
create_gif=args.gif,
progress_file=args.progress_file,
)
if result:
+20 -3
View File
@@ -138,13 +138,16 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
# Common
av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
@@ -196,6 +199,7 @@ wallx = [
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
@@ -209,10 +213,13 @@ groot = [
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -273,10 +280,13 @@ all = [
"lerobot[multi_task_dit]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[lingbot_va]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -287,6 +297,7 @@ all = [
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
@@ -367,6 +378,9 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
# torch.nn.functional.
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
[tool.ruff.lint.isort]
combine-as-imports = true
@@ -403,8 +417,11 @@ default.extend-ignore-identifiers-re = [
"ein",
"thw",
"inpt",
"arange",
"is_compileable",
"ROBOTIS",
"OT_VALUE"
"OT_VALUE",
"VanderBilt"
]
# TODO: Uncomment when ready to use
+70
View File
@@ -18,6 +18,7 @@ from __future__ import annotations
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)
########################################################################################
# Teleoperator smooth handover helpers
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
# being a root module.
########################################################################################
def teleop_supports_feedback(teleop) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback (i.e. have non-empty
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
``target_pos`` is expected to be in the teleop's action/feedback key space.
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
the robot action key space directly.
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
follower robot does not receive new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def follower_smooth_move_to(
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm to
the follower, the follower is brought to the teleop's current pose so the
robot meets the operator's hand rather than jumping to it on the first frame.
Both ``current`` and ``target`` must be in the robot action key space
(i.e. the output of ``robot_action_processor``).
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
+2 -2
View File
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
private: bool | None = None
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
+8 -3
View File
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
else:
del config_data[field]
del config_data[field]
modified = True
if not modified:
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
if config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = draccus.parse(
config_class=argtype,
config_path=config_path_cli or config_path,
args=cli_args,
)
response = fn(cfg, *args, **kwargs)
return response
+6
View File
@@ -177,6 +177,12 @@ class TrainPipelineConfig(HubMixin):
)
active_cfg = self.trainable_config
if self.rename_map and active_cfg.pretrained_path is None:
raise ValueError(
"`rename_map` requires a pretrained policy checkpoint. "
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{active_cfg.type}"
+8 -1
View File
@@ -250,7 +250,14 @@ class DatasetWriter:
for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
stacked_values = np.stack(episode_buffer[key])
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
stacked_values = stacked_values.reshape(episode_length)
episode_buffer[key] = stacked_values
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
+3 -2
View File
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
license: str | None = "apache-2.0",
tag_version: bool = True,
push_videos: bool = True,
private: bool = False,
private: bool | None = None,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
**card_kwargs,
@@ -543,7 +543,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
tag_version: If ``True``, create a Git tag for the current codebase
version.
push_videos: If ``False``, skip uploading the ``videos/`` directory.
private: If ``True``, create a private repository.
private: If ``True``, create a private repository. If ``None``
(default), defer to the org default on the Hub (only affects orgs).
allow_patterns: Glob pattern(s) restricting which files to upload.
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
of ``upload_folder`` for very large datasets.
+87 -16
View File
@@ -17,11 +17,13 @@ import contextlib
import glob
import importlib
import logging
import os
import queue
import shutil
import tempfile
import threading
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
@@ -191,15 +193,70 @@ def decode_video_frames_pyav(
return closest_frames
class VideoDecoderCache:
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
DEFAULT_DECODER_CACHE_SIZE = 100
"""Default LRU capacity for :class:`VideoDecoderCache`.
def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
Sized to comfortably hold a small rolling window of episodes worth of decoders
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
to the constructor (``None`` restores the legacy unbounded behaviour).
"""
def _default_max_cache_size() -> int | None:
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
if raw is None:
return DEFAULT_DECODER_CACHE_SIZE
raw = raw.strip().lower()
if raw in ("", "none", "unbounded", "-1"):
return None
try:
value = int(raw)
except ValueError as e:
raise ValueError(
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
) from e
if value <= 0:
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
return value
class VideoDecoderCache:
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
backing it. When the cache is full and a new path is requested, the
least-recently-used entry is evicted and its file handle is closed. This
bounds host-RAM growth when iterating over datasets with many distinct
video files (otherwise each ``DataLoader`` worker pins every decoder it has
ever opened until the process exits).
Args:
max_size: Maximum number of decoders to retain. ``None`` disables
eviction and restores legacy unbounded behaviour. Defaults to the
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
:data:`DEFAULT_DECODER_CACHE_SIZE`.
"""
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
raise ValueError(f"max_size must be positive or None; got {max_size}")
self.max_size: int | None = max_size # type: ignore[assignment]
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
self._lock = Lock()
def __contains__(self, video_path: object) -> bool:
with self._lock:
return str(video_path) in self._cache
def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one."""
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
@@ -211,22 +268,36 @@ class VideoDecoderCache:
video_path = str(video_path)
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
entry = self._cache.get(video_path)
if entry is not None:
self._cache.move_to_end(video_path)
return entry[0]
return self._cache[video_path][0]
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
# Evict LRU entries until we are back under the cap. We close
# evicted file handles immediately; the associated ``VideoDecoder``
# is released to the GC when its last reference goes away.
if self.max_size is not None:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
with contextlib.suppress(Exception):
evicted_handle.close()
return decoder
def clear(self):
"""Clear the cache and close file handles."""
"""Clear the cache and close all file handles."""
with self._lock:
for _, file_handle in self._cache.values():
file_handle.close()
with contextlib.suppress(Exception):
file_handle.close()
self._cache.clear()
def size(self) -> int:
+7 -1
View File
@@ -757,7 +757,7 @@ class RoboTwinEnvConfig(EnvConfig):
task: str = "beat_block_hammer" # single task or comma-separated list
fps: int = 25
episode_length: int = 300
episode_length: int = 1200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
# Available cameras from RoboTwin's aloha-agilex embodiment: head_camera
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
# must equal what SAPIEN actually renders.
observation_height: int = 240
observation_width: int = 320
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
action_mode: str = "joint"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
)
def __post_init__(self):
if self.action_mode == "ee":
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
for cam in cam_list:
self.features[f"pixels/{cam}"] = PolicyFeature(
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
observation_height=self.observation_height,
observation_width=self.observation_width,
episode_length=self.episode_length,
action_mode=self.action_mode,
)
+154 -6
View File
@@ -17,6 +17,7 @@ from __future__ import annotations
import importlib
import logging
import os
from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
@@ -41,10 +42,117 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
"right_camera",
)
ACTION_DIM = 14 # 7 DOF × 2 arms
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
EEF_ACTION_DIM = 16
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
DEFAULT_EPISODE_LENGTH = 300
DEFAULT_EPISODE_LENGTH = 1200
OFFICIAL_INSTRUCTION_ENV = "LEROBOT_ROBOTWIN_OFFICIAL_INSTRUCTION"
OFFICIAL_INSTRUCTION_TYPE_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_TYPE"
OFFICIAL_INSTRUCTION_MAX_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_MAX"
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a single-arm predicted delta pose onto the initial pose.
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
"""
from scipy.spatial.transform import Rotation
new_r = Rotation.from_quat(new_pose[3:7])
init_r = Rotation.from_quat(init_pose[3:7])
out_rot = (init_r * new_r).as_quat()
out_trans = new_pose[:3] + init_pose[:3]
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
out = np.concatenate([left, right])
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
return out
def _env_flag(name: str, default: bool = False) -> bool:
raw = os.environ.get(name)
if raw is None:
return default
return raw.strip().lower() in {"1", "true", "yes", "on"}
def _arm_for_block(block: Any) -> str:
return "left" if float(block.get_pose().p[0]) < 0 else "right"
def _robotwin_blocks_episode_info(task_name: str, env: Any) -> dict[str, str] | None:
"""Infer the episode-info dict used by RoboTwin's official instruction generator for block ranking."""
if task_name == "blocks_ranking_rgb":
return {
"{A}": "red block",
"{B}": "green block",
"{C}": "blue block",
"{a}": _arm_for_block(env.block1),
"{b}": _arm_for_block(env.block2),
"{c}": _arm_for_block(env.block3),
}
if task_name == "blocks_ranking_size":
return {
"{A}": "large block",
"{B}": "medium block",
"{C}": "small block",
"{a}": _arm_for_block(env.block1),
"{b}": _arm_for_block(env.block2),
"{c}": _arm_for_block(env.block3),
}
return None
def _generate_robotwin_official_instruction(task_name: str, env: Any) -> str:
"""Generate language with RoboTwin's official task templates, matching its eval client."""
fallback = task_name.replace("_", " ")
episode_info = _robotwin_blocks_episode_info(task_name, env)
if episode_info is None:
logger.warning("Official RoboTwin instruction is not implemented for task=%s; using %r.", task_name, fallback)
return fallback
try:
from description.utils.generate_episode_instructions import generate_episode_descriptions
except Exception:
logger.warning("Failed to import RoboTwin official instruction generator; using %r.", fallback, exc_info=True)
return fallback
instruction_type = os.environ.get(OFFICIAL_INSTRUCTION_TYPE_ENV, "seen")
try:
max_descriptions = int(os.environ.get(OFFICIAL_INSTRUCTION_MAX_ENV, "1000000"))
except ValueError:
max_descriptions = 1000000
results = generate_episode_descriptions(task_name, [episode_info], max_descriptions=max_descriptions)
if not results:
logger.warning("RoboTwin generated no official instructions for task=%s; using %r.", task_name, fallback)
return fallback
options = results[0].get(instruction_type) or results[0].get("seen") or results[0].get("unseen")
if not options:
logger.warning(
"RoboTwin generated no %s official instructions for task=%s; using %r.",
instruction_type,
task_name,
fallback,
)
return fallback
return str(np.random.choice(options))
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
DEFAULT_CAMERA_H = 240
DEFAULT_CAMERA_W = 320
@@ -234,6 +342,7 @@ class RoboTwinEnv(gym.Env):
observation_width: int | None = None,
episode_length: int = DEFAULT_EPISODE_LENGTH,
render_mode: str = "rgb_array",
action_mode: str = "joint",
):
super().__init__()
self.task_name = task_name
@@ -241,6 +350,13 @@ class RoboTwinEnv(gym.Env):
self.task_description = task_name.replace("_", " ")
self.episode_index = episode_index
self._reset_stride = n_envs
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
if action_mode not in ("joint", "ee"):
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
self.action_mode = action_mode
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
self._init_eef_pose: np.ndarray | None = None
self.camera_names = list(camera_names)
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
# The YAML-driven lookup is deferred to reset() so construction doesn't
@@ -271,7 +387,7 @@ class RoboTwinEnv(gym.Env):
}
)
self.action_space = spaces.Box(
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
)
def _ensure_env(self) -> None:
@@ -317,6 +433,18 @@ class RoboTwinEnv(gym.Env):
return {"pixels": images, "agent_pos": joint_state}
def _read_eef_pose(self) -> np.ndarray:
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
assert self._env is not None, "_read_eef_pose called before _ensure_env()"
ep = self._env.get_obs()["endpose"]
pose = (
list(ep["left_endpose"])
+ [ep["left_gripper"]]
+ list(ep["right_endpose"])
+ [ep["right_gripper"]]
)
return np.asarray(pose, dtype=np.float64)
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
self._ensure_env()
super().reset(seed=seed)
@@ -330,16 +458,32 @@ class RoboTwinEnv(gym.Env):
self.episode_index += self._reset_stride
self._step_count = 0
use_official_instruction = self.task_name in {"blocks_ranking_rgb", "blocks_ranking_size"}
if _env_flag(OFFICIAL_INSTRUCTION_ENV, default=use_official_instruction):
self.task_description = _generate_robotwin_official_instruction(self.task_name, self._env)
if hasattr(self._env, "set_instruction"):
self._env.set_instruction(instruction=self.task_description)
logger.info("RoboTwin official instruction | task=%s | %s", self.task_name, self.task_description)
else:
self.task_description = self.task_name.replace("_", " ")
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
if self.action_mode == "ee":
self._init_eef_pose = self._read_eef_pose()
obs = self._get_obs()
return obs, {"is_success": False, "task": self.task_name}
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
assert self._env is not None, "step() called before reset()"
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
if action.ndim != 1 or action.shape[0] != self._action_dim:
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
with torch.enable_grad():
if hasattr(self._env, "take_action"):
if self.action_mode == "ee":
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
self._env.take_action(ee_action, action_type="ee")
elif hasattr(self._env, "take_action"):
self._env.take_action(action)
else:
self._env.step(action)
@@ -398,6 +542,7 @@ def _make_env_fns(
observation_height: int,
observation_width: int,
episode_length: int,
action_mode: str = "joint",
) -> list[Callable[[], RoboTwinEnv]]:
"""Return n_envs factory callables for a single task."""
@@ -410,6 +555,7 @@ def _make_env_fns(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
return [partial(_make_one, i) for i in range(n_envs)]
@@ -423,6 +569,7 @@ def create_robotwin_envs(
observation_height: int = DEFAULT_CAMERA_H,
observation_width: int = DEFAULT_CAMERA_W,
episode_length: int = DEFAULT_EPISODE_LENGTH,
action_mode: str = "joint",
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboTwin 2.0 environments.
@@ -473,6 +620,7 @@ def create_robotwin_envs(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
+17 -3
View File
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _placo_available, require_package
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING or _placo_available:
_placo_runtime_error: ImportError | None = None
if TYPE_CHECKING:
import placo # type: ignore[import-not-found]
else:
placo = None
try:
import placo # type: ignore[import-not-found]
except ImportError as _placo_import_err:
placo = None
_placo_runtime_error = _placo_import_err
def _raise_if_placo_unusable() -> None:
if placo is None and _placo_runtime_error is not None:
raise ImportError(
f"placo is installed but failed to import: {_placo_runtime_error!s}"
) from _placo_runtime_error
class RobotKinematics:
@@ -44,6 +57,7 @@ class RobotKinematics:
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
require_package("placo", extra="placo-dep")
_raise_if_placo_unusable()
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)
+100 -18
View File
@@ -43,6 +43,7 @@ from .tables import (
CAN_CMD_SET_ZERO,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
HANDSHAKE_TIMEOUT_S,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
def _query_status_via_clear_fault(
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
) -> tuple[bool, can.Message | None]:
motor_name = self._get_motor_name(motor)
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
def _recv_status_via_clear_fault(
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
faulted_motors = []
for motor_name in self.motors:
has_fault, msg = self._query_status_via_clear_fault(motor_name)
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
if msg is None:
missing_motors.append(motor_name)
elif has_fault:
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
return responses
def _recv_all_messages_until_quiet(
self,
*,
timeout: float = RUNNING_TIMEOUT,
max_messages: int = 4096,
) -> list[can.Message]:
"""
Receive frames until the bus goes quiet.
Args:
timeout: Poll timeout used for each recv() call. Collection stops
when one recv() times out (quiet gap).
max_messages: Safety cap to prevent unbounded loops.
"""
out: list[can.Message] = []
max_messages = max(1, max_messages)
timeout = max(0.0, timeout)
try:
while len(out) < max_messages:
msg = self._bus().recv(timeout=timeout)
if msg is None:
break
out.append(msg)
except (can.CanError, OSError) as e:
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
return out
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
"""
Decode all received feedback frames and update cached motor states.
Returns:
Set of payload recv_ids that were successfully mapped to motors.
"""
processed_recv_ids: set[int] = set()
for msg in messages:
if len(msg.data) < 1:
logger.debug(
f"Dropping short CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
)
continue
recv_id = int(msg.data[0])
motor_name = self._recv_id_to_motor.get(recv_id)
if motor_name is None:
logger.debug(
f"Unmapped CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
)
continue
self._process_response(motor_name, msg)
processed_recv_ids.add(recv_id)
return processed_recv_ids
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
"""
Drain pending RX frames from the CAN interface.
This is used by higher-level controllers to drop stale feedback before issuing
a fresh read cycle, so subsequent state reads are based on most recent replies.
It should also be called once when a controller instance is created/connected,
to clear residual frames left on the interface from previous sessions.
"""
drained = 0
poll_timeout_s = max(0.0, poll_timeout_s)
max_messages = max(1, max_messages)
try:
while drained < max_messages:
msg = self._bus().recv(timeout=poll_timeout_s)
if msg is None:
break
drained += 1
except (can.CanError, OSError) as e:
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
return drained
def _speed_control(
self,
motor: NameOrID,
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Read every feedback frame until RX goes quiet, then decode all of them.
# This avoids dropping useful frames when responses from different motors interleave.
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
try:
self._decode_motor_state(msg.data)
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
logger.warning(
f"Failed to decode response from {motor} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the state cache."""
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
self._bus().send(msg)
updated_motors.append(motor)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
for response in responses.values():
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
if payload_motor_name is not None:
self._process_response(payload_motor_name, response)
else:
# Fallback: still attempt to decode based on payload byte0 mapping.
self._decode_motor_state(response.data)
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
for motor in updated_motors:
recv_id = self._get_motor_recv_id(motor)
if recv_id not in responses:
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def read_calibration(self) -> dict[str, MotorCalibration]:
+2 -1
View File
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.001
RUNNING_TIMEOUT = 0.003
HANDSHAKE_TIMEOUT_S = 0.05
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02
+22
View File
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Linear warmup followed by a constant learning rate.
Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
"""
num_warmup_steps: int = 1000
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
warmup_steps = self.num_warmup_steps or 0
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
+4
View File
@@ -20,6 +20,8 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
@@ -43,6 +45,8 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"LingBotVAConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
"PI0FastConfig",
+55 -2
View File
@@ -49,6 +49,8 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
@@ -56,6 +58,7 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -88,7 +91,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
"molmoact2".
Returns:
The policy class corresponding to the given name.
@@ -151,6 +155,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "molmoact2":
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
elif name == "lingbot_va":
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -168,7 +184,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
"smolvla", "wall_x", "molmoact2".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -203,6 +219,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
elif policy_type == "lingbot_va":
return LingBotVAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -231,6 +253,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
dataset_meta: Any | None
def make_pre_post_processors(
@@ -406,6 +429,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -414,6 +438,31 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MolmoAct2Config):
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
processors = make_molmoact2_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, LingBotVAConfig):
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
processors = make_lingbot_va_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
@@ -499,6 +548,10 @@ def make_policy(
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
if ds_meta is not None:
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
if callable(set_dataset_feature_metadata):
set_dataset_feature_metadata(ds_meta.features)
kwargs["config"] = cfg
@@ -60,6 +60,7 @@ class Eagle25VLPreTrainedModel(PreTrainedModel):
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
@@ -124,7 +124,6 @@ class Eagle25VLProcessor(ProcessorMixin):
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
+15 -9
View File
@@ -14,7 +14,7 @@
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model
@@ -206,7 +206,11 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc = AutoProcessor.from_pretrained(
str(cache_dir),
trust_remote_code=True,
fix_mistral_regex=False,
)
proc.tokenizer.padding_side = "left"
return proc
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/lingbot_va.mdx
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
from .configuration_lingbot_va import LingBotVAConfig
from .processor_lingbot_va import make_lingbot_va_pre_post_processors
__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]
def __getattr__(name):
if name == "LingBotVAPolicy":
from .modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -0,0 +1,168 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for the LingBot-VA policy.
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
video-diffusion stack. It interleaves prediction of future video latents and robot
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
upstream repository (https://github.com/Robbyant/lingbot-va).
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
and the ``transformer/config.json`` of the released checkpoints.
"""
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION
@PreTrainedConfig.register_subclass("lingbot_va")
@dataclass
class LingBotVAConfig(PreTrainedConfig):
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
# Wan transformer architecture
patch_size: tuple[int, int, int] = (1, 2, 2)
num_attention_heads: int = 24
attention_head_dim: int = 128
in_channels: int = 48
out_channels: int = 48
action_dim: int = 30
text_dim: int = 4096
freq_dim: int = 256
ffn_dim: int = 14336
num_layers: int = 30
cross_attn_norm: bool = True
eps: float = 1e-6
rope_max_seq_len: int = 1024
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
attn_mode: str = "torch"
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
wan_pretrained_path: str = "robbyant/lingbot-va-base"
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
text_encoder_device: str = "cpu"
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
obs_cam_keys: list[str] = field(
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
)
# Undo the LIBERO env processor's extra horizontal flip to match the model's training orientation.
image_hflip: bool = False
# Camera latent layout: "width_concat" (cameras concatenated on width; LIBERO) or
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
camera_layout: str = "width_concat"
# Inference hyperparameters (LIBERO defaults)
n_obs_steps: int = 1
height: int = 128
width: int = 128
action_per_frame: int = 4
frame_chunk_size: int = 4
attn_window: int = 30
num_inference_steps: int = 20
video_exec_step: int = -1
action_num_inference_steps: int = 50
guidance_scale: float = 5.0
action_guidance_scale: float = 1.0
snr_shift: float = 5.0
action_snr_shift: float = 0.05
max_sequence_length: int = 512 # UMT5 prompt length
# Subset of the 30-d action space used by the benchmark (LIBERO = 7-DoF). The action
# (un)normalization quantiles live in the checkpoint's ``policy_postprocessor.json``, not here.
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
save_predicted_video: bool = False
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
# quantile-(un)normalized inside the policy / dedicated processor steps.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-4
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 1000
def __post_init__(self):
super().__post_init__()
if self.attn_mode not in ("torch", "flashattn", "flex"):
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
@property
def chunk_size(self) -> int:
"""Number of single-step actions produced per autoregressive chunk."""
return self.frame_chunk_size * self.action_per_frame
@property
def n_action_steps(self) -> int:
"""Number of actions executed before refilling (the whole chunk)."""
return self.chunk_size
def validate_features(self) -> None:
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"LingBot-VA requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,87 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre/post-processor pipelines for the LingBot-VA policy.
The preprocessor passes inputs through (IDENTITY) and the postprocessor maps the policy's
``[-1, 1]`` actions back to physical units with the built-in ``UnnormalizerProcessorStep``
(QUANTILES) using per-channel q01/q99 restored from the checkpoint.
"""
from typing import Any
import torch
from lerobot.configs.types import FeatureType, NormalizationMode
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_lingbot_va import LingBotVAConfig
def make_lingbot_va_pre_post_processors(
config: LingBotVAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build the pre/post processor pipelines for LingBot-VA."""
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
# Unnormalize actions from [-1, 1] to physical units (QUANTILES) using q01/q99 restored from the checkpoint.
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map={FeatureType.ACTION: NormalizationMode.QUANTILES},
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_molmoact2_README.md
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import MolmoAct2Policy
from .processor_molmoact2 import make_molmoact2_pre_post_processors
__all__ = ["MolmoAct2Config", "MolmoAct2Policy", "make_molmoact2_pre_post_processors"]
@@ -0,0 +1,519 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
CosineDecayWithWarmupSchedulerConfig,
LRSchedulerConfig,
OptimizerConfig,
)
from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
class MolmoAct2Config(PreTrainedConfig):
"""MolmoAct2 policy backed by the converted HF checkpoint implementation."""
checkpoint_path: str = "allenai/MolmoAct2"
checkpoint_revision: str | None = None
checkpoint_force_download: bool = False
n_obs_steps: int = 1
chunk_size: int = 30
n_action_steps: int = 30
action_mode: str = "both"
inference_action_mode: str | None = None
discrete_action_tokenizer: str = "allenai/MolmoAct2-FAST-Tokenizer"
discrete_generation_max_steps: int | None = None
norm_tag: str | None = None
setup_type: str = ""
control_mode: str = ""
image_keys: list[str] = field(default_factory=list)
normalize_language: bool = True
add_setup_tokens: bool = True
add_control_tokens: bool = True
normalize_gripper: bool = False
num_state_tokens: int = 256
# Leave unset for the default MolmoAct2 sequence budget inferred from the fixed
# image/prompt/state/action token layout. Override only for unusual long prompts.
max_sequence_length: int | None = None
# Fixed by released MolmoAct2 checkpoints. We validate this at model load.
expected_max_action_dim: int = 32
# Flow-matching training knobs copied from the original MolmoAct2 training path.
num_flow_timesteps: int = 8
flow_matching_cutoff: float = 1.0
flow_matching_time_offset: float = 0.001
flow_matching_time_scale: float = 0.999
flow_matching_beta_alpha: float = 1.0
flow_matching_beta_beta: float = 1.5
num_inference_steps: int | None = None
mask_action_dim_padding: bool = True
enable_inference_cuda_graph: bool = True
# MolmoAct2-local eval option. When enabled, stochastic continuous action
# generation uses a rollout-local generator derived from eval_seed.
per_episode_seed: bool = False
eval_seed: int | None = None
rtc_config: RTCConfig | None = None
# Default is full finetuning with gradients from the action expert flowing into the VLM.
enable_lora_vlm: bool = False
lora_rank: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_bias: str = "none"
enable_lora_action_expert: bool = False
enable_knowledge_insulation: bool = False
freeze_embedding: bool = True
train_action_expert_only: bool = False
gradient_checkpointing: bool = False
model_dtype: str = "bfloat16"
softmax_auxiliary_loss: bool = True
softmax_auxiliary_loss_scale: float = 1e-4
discrete_loss_token_weighting: str = "root_subsegments_root_tokens"
optimizer_lr: float = 1e-5
optimizer_vit_lr: float = 5e-6
optimizer_connector_lr: float = 5e-6
optimizer_action_expert_lr: float = 5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-6
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES,
"ACTION": NormalizationMode.QUANTILES,
}
)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
dataset_feature_names: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
super().__post_init__()
if self.action_mode not in {"continuous", "discrete", "both"}:
raise ValueError(
f"Unsupported action_mode={self.action_mode!r}. "
"Expected one of {'continuous', 'discrete', 'both'}."
)
if self.inference_action_mode not in {None, "continuous", "discrete"}:
raise ValueError(
f"Unsupported inference_action_mode={self.inference_action_mode!r}. "
"Expected one of {None, 'continuous', 'discrete'}."
)
if self.inference_action_mode == "continuous" and self.action_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' cannot run continuous inference.")
if self.inference_action_mode == "discrete" and self.action_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' cannot run discrete inference.")
if self.train_action_expert_only and self.action_mode != "continuous":
raise ValueError("MolmoAct2 train_action_expert_only requires action_mode='continuous'.")
if self.train_action_expert_only and self.enable_lora_vlm:
raise ValueError("MolmoAct2 train_action_expert_only is incompatible with enable_lora_vlm.")
if self.enable_lora_action_expert and not self.enable_lora_vlm:
raise ValueError("MolmoAct2 enable_lora_action_expert requires enable_lora_vlm.")
if self.chunk_size < 1:
raise ValueError(f"chunk_size must be >= 1, got {self.chunk_size}.")
if self.n_action_steps < 1:
raise ValueError(f"n_action_steps must be >= 1, got {self.n_action_steps}.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})."
)
if self.expected_max_action_dim != 32:
raise ValueError("MolmoAct2 released checkpoints use expected_max_action_dim=32.")
if self.model_dtype not in {"float32", "bfloat16", "float16"}:
raise ValueError(
f"Unsupported model_dtype={self.model_dtype!r}. Expected 'float32', 'bfloat16', or 'float16'."
)
if self.lora_rank < 1:
raise ValueError(f"lora_rank must be >= 1, got {self.lora_rank}.")
if self.lora_alpha < 1:
raise ValueError(f"lora_alpha must be >= 1, got {self.lora_alpha}.")
if not 0 <= self.lora_dropout <= 1:
raise ValueError(f"lora_dropout must be in [0, 1], got {self.lora_dropout}.")
if self.lora_bias not in {"none", "all", "lora_only"}:
raise ValueError(
f"Unsupported lora_bias={self.lora_bias!r}. Expected one of 'none', 'all', or 'lora_only'."
)
if self.discrete_loss_token_weighting not in {
"none",
"token",
"root_tokens",
"root_subsegments",
"root_subsegments_root_tokens",
}:
raise ValueError(
f"Unsupported discrete_loss_token_weighting={self.discrete_loss_token_weighting!r}."
)
if self.discrete_generation_max_steps is not None and self.discrete_generation_max_steps < 1:
raise ValueError(
f"discrete_generation_max_steps must be >= 1 or None, got {self.discrete_generation_max_steps}."
)
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
def set_dataset_feature_metadata(self, features: dict[str, Any]) -> None:
self.dataset_feature_names = {}
for key in (ACTION, OBS_STATE):
feature = features.get(key) if isinstance(features, dict) else None
if isinstance(feature, dict) and feature.get("names") is not None:
self.dataset_feature_names[key] = feature["names"]
def validate_features(self) -> None:
"""Validate and set up MolmoAct2 input and output features."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"MolmoAct2 policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if OBS_STATE not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(0,),
)
self.input_features[OBS_STATE] = state_feature
if ACTION not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
@@ -0,0 +1,17 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
@@ -0,0 +1,237 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_tokenizer_location(
tokenizer_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["tokenizer"]
tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(tokenizer)
self.bpe_tokenizer = self.tokenizer
def __call__(self, action_chunk: np.array) -> np.array:
from scipy.fft import dct
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
from scipy.fft import idct
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert self.time_horizon is not None and self.action_dim is not None, (
"Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
)
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert decoded_dct_coeff.shape == (
self.time_horizon,
self.action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
from scipy.fft import dct
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert min_vocab_size <= vocab_size, (
f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
)
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)
@classmethod
def from_pretrained_local(
cls,
pretrained_model_name_or_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> "UniversalActionProcessor":
location = Path(
_resolve_tokenizer_location(
pretrained_model_name_or_path,
revision=revision,
force_download=force_download,
)
)
processor_config = {}
processor_config_path = location / "processor_config.json"
if processor_config_path.exists():
import json
processor_config = json.loads(processor_config_path.read_text())
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(location))
return cls(
tokenizer,
scale=processor_config.get("scale", 10),
vocab_size=processor_config.get("vocab_size", 1024),
min_token=processor_config.get("min_token", 0),
action_dim=processor_config.get("action_dim"),
time_horizon=processor_config.get("time_horizon"),
)
@@ -0,0 +1,553 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class MolmoAct2VitConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
>>> # Initializing a MolmoAct2VitConfig
>>> configuration = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2VisionTransformer (with random weights)
>>> model = MolmoAct2VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
base_config_key = "vit_config"
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
num_hidden_layers: int = 27,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
hidden_act: str = "gelu_pytorch_tanh",
layer_norm_eps: float = 1e-6,
image_default_input_size: tuple[int, int] = (378, 378),
image_patch_size: int = 14,
image_num_pos: int = 577,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
initializer_range: float = 0.02,
float32_attention: bool = True,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.image_default_input_size = image_default_input_size
self.image_patch_size = image_patch_size
self.image_num_pos = image_num_pos
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.initializer_range = initializer_range
self.float32_attention = float32_attention
@property
def image_num_patch(self):
h, w = self.image_default_input_size
return h // self.image_patch_size, w // self.image_patch_size
class MolmoAct2AdapterConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
>>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
>>> vit_config = MolmoAct2VitConfig()
>>> adapter_config = MolmoPoolingConfig()
>>> # Initializing a MolmoAct2VisionBackbone (with random weights)
>>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
>>> # Accessing the model configuration
>>> vit_configuration = model.vit_config
>>> adapter_configuration = model.adapter_config
```"""
model_type = "molmoact2"
base_config_key = "adapter_config"
def __init__(
self,
vit_layers: tuple = (-3, -9),
pooling_attention_mask: bool = False,
hidden_size: int = 1152,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 72,
float32_attention: bool = True,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
hidden_act: str = "silu",
intermediate_size: int = 18944,
text_hidden_size: int = 3584,
image_feature_dropout: float = 0.0,
initializer_range: float = 0.02,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(attn_implementation=attn_implementation, **kwargs)
self.vit_layers = vit_layers
self.pooling_attention_mask = pooling_attention_mask
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.float32_attention = float32_attention
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.text_hidden_size = text_hidden_size
self.image_feature_dropout = image_feature_dropout
self.initializer_range = initializer_range
class MolmoAct2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
`MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Example:
```python
>>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
>>> # Initializing a MolmoAct2TextConfig
>>> configuration = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2TextModel (with random weights)
>>> model = MolmoAct2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2_text"
base_config_key = "text_config"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"blocks.*.self_attn.att_proj": "colwise",
"blocks.*.self_attn.attn_out": "rowwise",
"blocks.*.mlp.ff_proj": "colwise",
"blocks.*.mlp.ff_out": "rowwise",
}
base_model_pp_plan = {
"wte": (["input_ids"], ["inputs_embeds"]),
"blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
"ln_f": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: int | None = 4,
head_dim: int = 128,
vocab_size: int = 152064,
additional_vocab_size: int = 128,
qkv_bias: bool = True,
num_hidden_layers: int = 48,
intermediate_size: int = 18944,
hidden_act: str = "silu",
embedding_dropout: float = 0.0,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
max_position_embeddings: int = 4096,
rope_theta: float = 1000000.0,
rope_scaling: dict[str, Any] = None,
rope_scaling_layers: list[int] | None = None,
use_qk_norm: bool = False,
qk_norm_type: str = "olmo",
layer_norm_eps: int = 1e-6,
norm_after: bool = False,
initializer_range: float = 0.02,
use_cache=True,
tie_word_embeddings=False,
attn_implementation: str = "eager",
**kwargs,
):
self.attn_implementation = attn_implementation
super().__init__(
tie_word_embeddings=tie_word_embeddings, attn_implementation=attn_implementation, **kwargs
)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.vocab_size = vocab_size
self.additional_vocab_size = additional_vocab_size
self.qkv_bias = qkv_bias
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.embedding_dropout = embedding_dropout
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.rope_scaling_layers = rope_scaling_layers
self.use_qk_norm = use_qk_norm
self.qk_norm_type = qk_norm_type
self.layer_norm_eps = layer_norm_eps
self.norm_after = norm_after
self.initializer_range = initializer_range
self.use_cache = use_cache
# Validate the correctness of rotary position embeddings parameters
rope_config_validation(self)
class MolmoAct2ActionExpertConfig(PretrainedConfig):
r"""Configuration for the MolmoAct2 modern action expert."""
model_type = "molmoact2_action_expert"
base_config_key = "action_expert_config"
def __init__(
self,
max_action_horizon: int = 32,
max_action_dim: int = 32,
hidden_size: int = 1024,
num_layers: int = 32,
num_heads: int = 16,
mlp_ratio: float = 8.0 / 3.0,
ffn_multiple_of: int = 256,
timestep_embed_dim: int = 256,
dropout: float = 0.0,
attn_dropout: float = 0.0,
context_layer_norm: bool = True,
qk_norm: bool = True,
qk_norm_eps: float = 1e-6,
rope: bool = True,
causal_attn: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.max_action_horizon = max_action_horizon
self.max_action_dim = max_action_dim
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.ffn_multiple_of = ffn_multiple_of
self.timestep_embed_dim = timestep_embed_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.context_layer_norm = context_layer_norm
self.qk_norm = qk_norm
self.qk_norm_eps = qk_norm_eps
self.rope = rope
self.causal_attn = causal_attn
def to_dict(self):
output = super().to_dict()
# These are derived from the parent MolmoAct2Config for HF exports. Keeping
# them out of the public nested config avoids duplicated sources of truth.
output.pop("max_action_horizon", None)
output.pop("max_action_dim", None)
return output
class MolmoAct2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
Example:
```python
>>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
>>> # Initializing a MolmoAct2VitConfig
>>> vit_config = MolmoAct2VitConfig()
>>> # Initializing a MolmoAct2AdapterConfig
>>> adapter_config = MolmoAct2AdapterConfig()
>>> # Initializing a MolmoAct2TextConfig
>>> text_config = MolmoAct2TextConfig()
>>> # Initializing a MolmoAct2Config
>>> configuration = MolmoAct2Config(
>>> vit_config=vit_config,
>>> adapter_config=adapter_config,
>>> text_config=text_config,
>>> image_start_token_id=151936,
>>> image_end_token_id=151937,
>>> image_patch_id=151938,
>>> image_col_id=151939,
>>> low_res_image_start_token_id=151940,
>>> image_low_res_id=151942,
>>> frame_start_token_id=151943,
>>> frame_end_token_id=151944,
>>> )
>>> # Initializing a model
>>> model = MolmoAct2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "molmoact2"
sub_configs = {
"text_config": MolmoAct2TextConfig,
"vit_config": MolmoAct2VitConfig,
"adapter_config": MolmoAct2AdapterConfig,
"action_expert_config": MolmoAct2ActionExpertConfig,
}
def __init__(
self,
vit_config: MolmoAct2VitConfig = None,
adapter_config: MolmoAct2AdapterConfig = None,
text_config: MolmoAct2TextConfig = None,
action_expert_config: MolmoAct2ActionExpertConfig = None,
image_start_token_id: int = None,
low_res_image_start_token_id: int = None,
image_end_token_id: int = None,
image_low_res_id: int = None,
image_patch_id: int = None,
image_col_id: int = None,
frame_start_token_id: int = None,
frame_end_token_id: int = None,
use_frame_special_tokens: bool = True,
initializer_range: float = 0.02,
add_action_expert: bool = True,
max_action_dim: int = 32,
max_action_horizon: int = 30,
n_obs_steps: int = 30,
action_mode: str = "both",
state_format: str = "discrete",
flow_matching_num_steps: int = 10,
flow_matching_cutoff: float = 1.0,
flow_matching_time_offset: float = 0.001,
flow_matching_time_scale: float = 0.999,
flow_matching_beta_alpha: float = 1.0,
flow_matching_beta_beta: float = 1.5,
mask_action_dim_padding: bool = True,
enable_depth_reasoning: bool = False,
depth_mode: int = 2,
num_depth_codes: int = 100,
action_expert_depth_gate: bool = False,
action_expert_depth_gate_per_layer: bool = False,
action_expert_depth_gate_init_bias: float = -4.0,
action_output_token_id: int = None,
action_start_token_id: int = None,
action_end_token_id: int = None,
action_token_start_id: int = None,
num_action_tokens: int = 0,
depth_output_token_id: int = None,
depth_start_token_id: int = None,
depth_end_token_id: int = None,
depth_token_start_id: int = None,
num_depth_tokens: int = 0,
state_start_token_id: int = None,
state_end_token_id: int = None,
state_token_start_id: int = None,
num_state_tokens: int = 0,
add_setup_tokens: bool = True,
add_control_tokens: bool = True,
norm_stats_filename: str = "norm_stats.json",
**kwargs,
):
super().__init__(**kwargs)
if vit_config is None:
self.vit_config = MolmoAct2VitConfig()
elif isinstance(vit_config, dict):
self.vit_config = MolmoAct2VitConfig(**vit_config)
else:
self.vit_config = vit_config
if adapter_config is None:
self.adapter_config = MolmoAct2AdapterConfig()
elif isinstance(adapter_config, dict):
self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
else:
self.adapter_config = adapter_config
if text_config is None:
self.text_config = MolmoAct2TextConfig()
elif isinstance(text_config, dict):
self.text_config = MolmoAct2TextConfig(**text_config)
else:
self.text_config = text_config
self.add_action_expert = bool(add_action_expert)
if not self.add_action_expert:
self.action_expert_config = None
elif action_expert_config is None:
self.action_expert_config = MolmoAct2ActionExpertConfig(
max_action_horizon=max_action_horizon,
max_action_dim=max_action_dim,
num_layers=self.text_config.num_hidden_layers,
)
elif isinstance(action_expert_config, dict):
self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
else:
self.action_expert_config = action_expert_config
if self.add_action_expert:
self.action_expert_config.max_action_dim = int(max_action_dim)
self.action_expert_config.max_action_horizon = int(max_action_horizon)
self._validate_release_action_config(
state_format=state_format,
)
self.image_start_token_id = image_start_token_id
self.low_res_image_start_token_id = low_res_image_start_token_id
self.image_end_token_id = image_end_token_id
self.image_low_res_id = image_low_res_id
self.image_high_res_id = image_patch_id
self.image_patch_id = image_patch_id
self.image_col_id = image_col_id
self.frame_start_token_id = frame_start_token_id
self.frame_end_token_id = frame_end_token_id
self.use_frame_special_tokens = use_frame_special_tokens
self.initializer_range = initializer_range
self.max_action_dim = max_action_dim
self.max_action_horizon = max_action_horizon
self.n_obs_steps = n_obs_steps
self.action_mode = action_mode
self.state_format = state_format
self.flow_matching_num_steps = flow_matching_num_steps
self.flow_matching_cutoff = flow_matching_cutoff
self.flow_matching_time_offset = flow_matching_time_offset
self.flow_matching_time_scale = flow_matching_time_scale
self.flow_matching_beta_alpha = flow_matching_beta_alpha
self.flow_matching_beta_beta = flow_matching_beta_beta
self.mask_action_dim_padding = mask_action_dim_padding
self.enable_depth_reasoning = enable_depth_reasoning
self.depth_mode = depth_mode
self.num_depth_codes = num_depth_codes
self.action_expert_depth_gate = action_expert_depth_gate
self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
self.action_output_token_id = action_output_token_id
self.action_start_token_id = action_start_token_id
self.action_end_token_id = action_end_token_id
self.action_token_start_id = action_token_start_id
self.num_action_tokens = num_action_tokens
self.depth_output_token_id = depth_output_token_id
self.depth_start_token_id = depth_start_token_id
self.depth_end_token_id = depth_end_token_id
self.depth_token_start_id = depth_token_start_id
self.num_depth_tokens = num_depth_tokens
self.state_start_token_id = state_start_token_id
self.state_end_token_id = state_end_token_id
self.state_token_start_id = state_token_start_id
self.num_state_tokens = num_state_tokens
self.add_setup_tokens = add_setup_tokens
self.add_control_tokens = add_control_tokens
self.norm_stats_filename = norm_stats_filename
@staticmethod
def _validate_release_action_config(
*,
state_format: str,
) -> None:
if state_format != "discrete":
raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
@property
def image_num_patch(self):
assert self.vit_config is not None
return self.vit_config.image_num_patch
@property
def num_attention_heads(self):
return self.text_config.num_attention_heads
@property
def num_key_value_heads(self):
return self.text_config.num_key_value_heads
@property
def head_dim(self):
return self.text_config.head_dim
@property
def num_hidden_layers(self):
return self.text_config.num_hidden_layers
@property
def hidden_size(self):
return self.text_config.hidden_size
@property
def vocab_size(self):
return self.text_config.vocab_size
@property
def max_position_embeddings(self):
return self.text_config.max_position_embeddings
MolmoAct2VitConfig.register_for_auto_class()
MolmoAct2AdapterConfig.register_for_auto_class()
MolmoAct2TextConfig.register_for_auto_class()
MolmoAct2ActionExpertConfig.register_for_auto_class()
MolmoAct2Config.register_for_auto_class()
@@ -0,0 +1,564 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
if i * j <= max_num_crops:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [1, 2]
# The original size can be zero in rare cases if the image is smaller than the margin
# In those cases letting the scale become infinite means the tiling is based on the
# other side, or falls back to the smallest tiling
with np.errstate(divide="ignore"):
required_scale_d = (candidate_resolutions.astype(np.float32) / original_size,)
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def build_overlapping_crops(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
"""Decompose an image into a set of overlapping crops
:return crop_arr: [n_crops, h, w, 3] The crops
:return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image
the crops were extracted from, what patch in `crop_arr` it corresponds to
"""
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
assert base_image_input_size[0] == base_image_input_size[1]
left_margin, right_margin = overlap_margins
total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * image_patch_size
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Decide how to tile the image, to account for the overlap margins we compute the tiling
# as if we had an image without the margins and were using a crop size without the margins
tiling = select_tiling(
original_image_h - total_margin_pixels,
original_image_w - total_margin_pixels,
crop_window_size,
max_crops,
)
src = resize_image(
image,
[
tiling[0] * crop_window_size + total_margin_pixels,
tiling[1] * crop_window_size + total_margin_pixels,
],
resample,
)
src = normalize_image(src, image_mean, image_std)
# Now we have to split the image into crops, and track what patches came from
# where in `patch_idx_arr`
n_crops = tiling[0] * tiling[1]
crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype)
patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32)
on_crop = 0
for i in range(tiling[0]):
# Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size`
# which results in overlapping crop windows
y0 = i * crop_window_size
for j in range(tiling[1]):
x0 = j * crop_window_size
crop_arr[on_crop] = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
patch_idx = np.arange(crop_patch_w * crop_patch_h).reshape(crop_patch_h, crop_patch_w)
patch_idx += on_crop * crop_patch_h * crop_patch_w
# Mask out idx that are in the overlap region
if i != 0:
patch_idx[:left_margin, :] = -1
if j != 0:
patch_idx[:, :left_margin] = -1
if i != tiling[0] - 1:
patch_idx[-right_margin:, :] = -1
if j != tiling[1] - 1:
patch_idx[:, -right_margin:] = -1
patch_idx_arr[on_crop] = patch_idx
on_crop += 1
# `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr`
# so it is ordered left-to-right order
patch_idx_arr = np.reshape(patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w])
patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3])
patch_idx_arr = np.reshape(patch_idx_arr, [-1])
# Now get the parts not in the overlap region, so it should map each patch in `src`
# to the correct patch it should come from in `crop_arr`
patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape(
src.shape[0] // image_patch_size,
src.shape[1] // image_patch_size,
)
return crop_arr, patch_idx_arr
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: np.ndarray,
max_crops: int,
overlap_margins: list[int],
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
crop_mode: str = "overlap-and-resize-c2",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each (low-res, high-res) image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
pooling_w = image_pooling_w
pooling_h = image_pooling_h
crop_patch_w = base_image_input_size[1] // base_image_input_d
crop_patch_h = base_image_input_size[0] // base_image_input_d
if crop_mode == "resize":
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [np.array([resized_h, resized_w, 0, 0])]
return (
np.stack(image_grid, 0),
batch_pixels_to_patches(resized, image_patch_size),
resize_idx,
)
if crop_mode not in {"overlap-and-resize-c2", "overlap-and-resize"}:
raise ValueError(f"Unsupported MolmoAct2 image crop_mode {crop_mode!r}.")
crop_arr, patch_idx_arr = build_overlapping_crops(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
# Finally do the same for the global image
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
crop_arr = np.concatenate([resized, crop_arr], 0)
resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
resized_h, resized_w = resize_idx.shape[:2]
resize_idx = resize_idx.reshape([-1, pooling_h * pooling_w])
# Global image goes first, so the order of patches in previous crops gets increased
pooling_idx = np.where(pooling_idx >= 0, pooling_idx + crop_patch_h * crop_patch_w, -1)
pooling_idx = np.concatenate([resize_idx, pooling_idx])
image_grid = [np.array([resized_h, resized_w, h, w])]
return (np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), pooling_idx)
class MolmoAct2ImagesKwargs(ImagesKwargs, total=False):
max_crops: int | None
overlap_margins: list[int] | None
crop_mode: str | None
patch_size: int | None
pooling_size: list[int] | None
class MolmoAct2ImageProcessor(BaseImageProcessor):
r"""
Constructs a MolmoAct2 image processor that preprocesses images for the model.
Args:
size (`dict[str, int]` *optional*, defaults to `{"height": 378, "width": 378}`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use when resizing the image.
image_mean (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `list[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `8`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `[4, 4]`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to 14):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `[2, 2]`):
The pooling size of the vision adapter.
"""
model_input_names = ["pixel_values", "image_token_pooling", "image_grids", "image_num_crops"]
def __init__(
self,
size: dict[str, int] | None = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
self.resample = resample
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.crop_mode = crop_mode
self.patch_size = patch_size
self.pooling_size = pooling_size
def preprocess(
self,
images: ImageInput,
size: dict[str, int] | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
max_crops: int | None = None,
overlap_margins: list[int] | None = None,
crop_mode: str | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Args:
images (`ImageInput`):
Image to preprocess.
size (`dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
max_crops (`int`, *optional*, defaults to `self.max_crops`):
Maximum number of crops to use per image.
overlap_margins (`list[int]`, *optional*, defaults to `self.overlap_margins`):
Overlap margins to use.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values`: The preprocessed images.
- `image_token_pooling`: The indices of the patches in `crops` to pool for each token in `image_tokens`.
- `image_grids`: The image grids.
- `image_num_crops`: The number of crops for each image.
"""
if size is not None:
if "height" not in size or "width" not in size:
raise ValueError("size must contain 'height' and 'width' keys.")
else:
size = {**self.size}
base_image_input_size = [size["height"], size["width"]]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
crop_mode = crop_mode or self.crop_mode
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
if images is not None:
images = self.fetch_images(images)
images = make_flat_list_of_images(images)
if images is not None and not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
data = {}
if images is not None:
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
batch_num_crops = []
for image in images:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
image,
max_crops,
overlap_margins,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
crop_mode,
)
batch_grids.append(image_grid)
batch_crops.append(crops)
batch_pooled_patches_idx.append(pooled_idx)
batch_num_crops.append(crops.shape[0])
pixel_values = np.concatenate(batch_crops, 0)
image_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
image_grids = np.concatenate(batch_grids, 0)
image_num_crops = np.array(batch_num_crops)
data.update(
pixel_values=pixel_values,
image_token_pooling=image_token_pooling,
image_grids=image_grids,
image_num_crops=image_num_crops,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2ImageProcessor.register_for_auto_class()
@@ -0,0 +1,748 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@dataclass
class _ActionFlowInputs:
trajectory: torch.Tensor
context: Any
modulations: Sequence[Any]
action_dim_is_pad: torch.Tensor | None
@dataclass
class _ActionFlowCudaGraph:
key: tuple[Any, ...]
graph: torch.cuda.CUDAGraph
static_inputs: _ActionFlowInputs
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphLayerStage:
residual: torch.Tensor
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphPostStage:
graph: torch.cuda.CUDAGraph
attn_context: torch.Tensor
@dataclass
class _DepthDecodeCudaGraph:
cache_key: tuple[Any, ...]
pre_graph: torch.cuda.CUDAGraph
token_ids: torch.Tensor
cos: torch.Tensor
sin: torch.Tensor
positions: torch.Tensor
stages: Sequence[_DepthDecodeCudaGraphLayerStage]
post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
output: torch.Tensor
@dataclass
class _DepthDecodeCudaGraphSpec:
eligible: bool
cache_key_prefix: tuple[Any, ...]
num_hidden_layers: int
head_dim: int
num_attention_heads: int
def _cache_seq_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return 0
seq_len = past_key_values.get_seq_length()
if torch.is_tensor(seq_len):
return int(seq_len.item())
return int(seq_len)
def _cache_max_len_int(past_key_values: Cache | None) -> int:
if past_key_values is None:
return -1
max_len = past_key_values.get_max_cache_shape()
if torch.is_tensor(max_len):
return int(max_len.item())
return int(max_len)
def _iter_cache_key_values(
past_key_values: Cache,
) -> Iterable[tuple[torch.Tensor | None, torch.Tensor | None]]:
layers = getattr(past_key_values, "layers", None)
if layers is not None:
for layer in layers:
yield getattr(layer, "keys", None), getattr(layer, "values", None)
return
for layer in past_key_values:
yield layer[0], layer[1]
class _DepthDecodeStaticLayerCache:
is_compileable = False
is_sliding = False
def __init__(self, max_cache_len: int) -> None:
self.max_cache_len = int(max_cache_len)
self.cumulative_length = 0
self.keys: torch.Tensor | None = None
self.values: torch.Tensor | None = None
def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
bsz, n_heads = key_states.shape[:2]
self.keys = torch.empty(
(bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
dtype=key_states.dtype,
device=key_states.device,
)
self.values = torch.empty(
(bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
dtype=value_states.dtype,
device=value_states.device,
)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
*args,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.keys is None:
self._allocate(key_states, value_states)
start = self.cumulative_length
end = start + key_states.shape[-2]
if end > self.max_cache_len:
raise RuntimeError(f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}.")
self.keys[:, :, start:end, :].copy_(key_states)
self.values[:, :, start:end, :].copy_(value_states)
self.cumulative_length = end
return self.keys[:, :, :end, :], self.values[:, :, :end, :]
def get_seq_length(self) -> int:
return self.cumulative_length
def get_max_cache_shape(self) -> int:
return -1
def reset(self) -> None:
self.cumulative_length = 0
class _DepthDecodeStaticCache(Cache):
def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
text_config = config.get_text_config(decoder=True)
super().__init__(
layers=[
_DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
for _ in range(text_config.num_hidden_layers)
]
)
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
return self.layers[layer_idx].get_max_cache_shape()
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class ActionCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.enabled = True
self.action_flow_graph: _ActionFlowCudaGraph | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
action_model = self.model
if not self.enabled:
return False
if action_model.training or action_model._require_action_expert().training:
return False
if inputs.trajectory.device.type != "cuda":
return False
def all_on_cuda():
yield inputs.trajectory
for k, v in inputs.context.kv_contexts:
yield k
yield v
for t in (
inputs.context.cross_mask,
inputs.context.self_mask,
inputs.context.valid_action,
inputs.action_dim_is_pad,
):
if t is not None:
yield t
if inputs.context.rope_cache is not None:
yield from inputs.context.rope_cache
for step in inputs.modulations:
yield step.conditioning
for block_modulation in step.block_modulations:
yield from block_modulation
yield from step.final_modulation
return all(t.device.type == "cuda" for t in all_on_cuda())
def run_action_flow(
self,
inputs: _ActionFlowInputs,
steps: int,
run_loop,
) -> torch.Tensor:
key = _cuda_graph_key(inputs, steps)
cache = self.action_flow_graph
if cache is None or cache.key != key:
static_inputs = _clone_static_inputs(inputs)
graph, output = _capture_cuda_graph(
lambda: run_loop(static_inputs, steps),
inputs.trajectory.device,
after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
)
cache = _ActionFlowCudaGraph(
key=key,
graph=graph,
static_inputs=static_inputs,
output=output,
)
self.action_flow_graph = cache
else:
_copy_inputs_(cache.static_inputs, inputs)
cache.graph.replay()
return cache.output.clone()
class DepthDecodeCudaGraphManager:
def __init__(self, model: Any) -> None:
self.model = model
self.backbone = model.model
self.enabled = True
self.graph: _DepthDecodeCudaGraph | None = None
self.graph_spec: _DepthDecodeCudaGraphSpec | None = None
def set_enabled(self, enabled: bool) -> None:
self.enabled = bool(enabled)
def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
return _DepthDecodeStaticCache(
config=self.model.config.text_config,
max_cache_len=max_cache_len,
)
def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
static = self.graph_spec
if static is None:
cfg = self.backbone.transformer.config
rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
static = _DepthDecodeCudaGraphSpec(
eligible=(
not cfg.norm_after
and cfg.rope_scaling_layers is None
and getattr(rotary_emb, "rope_type", None) == "default"
and cfg._attn_implementation == "sdpa"
),
cache_key_prefix=(
cfg.hidden_size,
cfg.num_attention_heads,
cfg.num_key_value_heads,
cfg.head_dim,
cfg.num_hidden_layers,
cfg.use_qk_norm,
cfg.qk_norm_type,
cfg._attn_implementation,
),
num_hidden_layers=cfg.num_hidden_layers,
head_dim=cfg.head_dim,
num_attention_heads=cfg.num_attention_heads,
)
self.graph_spec = static
return static
def can_use(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
) -> bool:
if not self.enabled or self.model.training or self.backbone.transformer.training:
return False
if next_input_ids.device.type != "cuda":
return False
if next_input_ids.ndim != 2 or next_input_ids.shape[0] != 1 or next_input_ids.shape[1] != 1:
return False
if not isinstance(past_key_values, _DepthDecodeStaticCache):
return False
if not torch.is_tensor(attention_bias) or attention_bias.device != next_input_ids.device:
return False
return self._depth_decode_spec().eligible
def _depth_decode_key(
self,
next_input_ids: torch.Tensor,
attention_bias: torch.Tensor,
) -> tuple[Any, ...]:
device = next_input_ids.device
return (
self._depth_decode_spec().cache_key_prefix,
device.type,
device.index,
self.model.lm_head.weight.dtype,
attention_bias.shape[-1],
)
def _select_depth_decode_rope(self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int) -> None:
emb = self.backbone.transformer.rotary_emb
cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
def _depth_decode_pre_layer(
self,
layer_idx: int,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
residual = hidden_states
hidden_states = block.attn_norm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, attention.head_dim)
qkv = attention.att_proj(hidden_states)
query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
value_states = value_states.view(hidden_shape)
apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
if apply_qk_norm and not norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.view(hidden_shape)
key_states = key_states.view(hidden_shape)
if norm_after_view:
query_states = attention.q_norm(query_states)
key_states = attention.k_norm(key_states)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)
return residual, query_states, key_states, value_states
def _depth_decode_pre0(
self,
token_ids: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
inputs_embeds = self.model._embed_base_tokens(token_ids)
return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
def _depth_decode_post_layer(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
block = self.backbone.transformer.blocks[layer_idx]
attention = block.self_attn
input_shape = residual.shape[:-1]
attn_output = attn_context.reshape(*input_shape, -1).contiguous()
attn_output = attention.attn_out(attn_output)
hidden_states = residual + block.dropout(attn_output)
residual = hidden_states
hidden_states = block.ff_norm(hidden_states)
hidden_states = block.mlp(hidden_states)
hidden_states = residual + block.dropout(hidden_states)
return hidden_states
def _depth_decode_post_and_pre_next(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
def _depth_decode_last_post(
self,
layer_idx: int,
residual: torch.Tensor,
attn_context: torch.Tensor,
) -> torch.Tensor:
hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
return self.backbone.transformer.ln_f(hidden_states)
def _build_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
text_config = self.backbone.transformer.config
device = next_input_ids.device
dtype = self.model.lm_head.weight.dtype
static = self._depth_decode_spec()
num_layers = static.num_hidden_layers
head_dim = static.head_dim
max_cache_len = int(attention_bias.shape[-1])
max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
self.backbone.transformer.prepare_rope_cache(device=device, max_seq_len=max_rope_len)
token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
sin = torch.empty_like(cos)
positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
context_shape = (1, 1, static.num_attention_heads, head_dim)
token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(cos, sin, past_length=past_length)
pre_graph, pre_output = _capture_cuda_graph(
lambda: self._depth_decode_pre0(token_ids, cos, sin),
device,
)
stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
post_graphs = []
for layer_idx in range(num_layers - 1):
stage = stages[-1]
attn_context = torch.empty(context_shape, device=device, dtype=dtype)
graph, output = _capture_cuda_graph(
lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
self._depth_decode_post_and_pre_next(
layer_idx,
stage.residual,
attn_context,
cos,
sin,
)
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context))
stages.append(_DepthDecodeCudaGraphLayerStage(*output))
last_stage = stages[-1]
last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
last_graph, last_output = _capture_cuda_graph(
lambda: self._depth_decode_last_post(
num_layers - 1,
last_stage.residual,
last_attn_context,
),
device,
)
post_graphs.append(_DepthDecodeCudaGraphPostStage(graph=last_graph, attn_context=last_attn_context))
return _DepthDecodeCudaGraph(
cache_key=self._depth_decode_key(next_input_ids, attention_bias),
pre_graph=pre_graph,
token_ids=token_ids,
cos=cos,
sin=sin,
positions=positions,
stages=tuple(stages),
post_graphs=tuple(post_graphs),
output=last_output,
)
def _get_depth_decode_graph(
self,
next_input_ids: torch.Tensor,
*,
past_length: int,
attention_bias: torch.Tensor,
) -> _DepthDecodeCudaGraph:
key = self._depth_decode_key(next_input_ids, attention_bias)
decode_graph = self.graph
if decode_graph is None or decode_graph.cache_key != key:
decode_graph = self._build_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
self.graph = decode_graph
else:
decode_graph.token_ids.copy_(next_input_ids)
self._select_depth_decode_rope(decode_graph.cos, decode_graph.sin, past_length=past_length)
return decode_graph
def _run_depth_decode_attention_core(
self,
layer_idx: int,
stage: _DepthDecodeCudaGraphLayerStage,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
cache_position: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
attention = self.backbone.transformer.blocks[layer_idx].self_attn
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
stage.key,
stage.value,
layer_idx,
cache_kwargs,
)
key_states = _repeat_kv(key_states, attention.num_key_value_groups)
value_states = _repeat_kv(value_states, attention.num_key_value_groups)
attn_output = F.scaled_dot_product_attention(
stage.query,
key_states,
value_states,
attn_mask=attention_bias,
dropout_p=0.0,
is_causal=False,
)
return attn_output.transpose(1, 2)
def run(
self,
next_input_ids: torch.Tensor,
*,
past_key_values: Cache,
attention_bias: torch.Tensor,
past_length: int,
) -> tuple[torch.Tensor, Cache]:
end = past_length + 1
decode_graph = self._get_depth_decode_graph(
next_input_ids,
past_length=past_length,
attention_bias=attention_bias,
)
cache_position = decode_graph.positions[past_length:end]
attention_bias_q = attention_bias[:, :, past_length:end, :end]
decode_graph.pre_graph.replay()
for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
attn_context = self._run_depth_decode_attention_core(
layer_idx,
decode_graph.stages[layer_idx],
past_key_values=past_key_values,
attention_bias=attention_bias_q,
cache_position=cache_position,
cos=decode_graph.cos,
sin=decode_graph.sin,
)
post_graph.attn_context.copy_(attn_context)
post_graph.graph.replay()
return decode_graph.output, past_key_values
def _cuda_graph_tensor_signature(
tensor: torch.Tensor | None,
) -> tuple[Any, ...] | None:
if tensor is None:
return None
return (
tuple(tensor.shape),
tuple(tensor.stride()),
str(tensor.dtype),
str(tensor.device),
)
def _cuda_graph_context_signature(context: Any) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
sig(context.cross_mask),
sig(context.self_mask),
sig(context.valid_action),
None if context.rope_cache is None else tuple(sig(t) for t in context.rope_cache),
)
def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return tuple(
(
sig(step.conditioning),
tuple(tuple(sig(t) for t in block_modulation) for block_modulation in step.block_modulations),
tuple(sig(t) for t in step.final_modulation),
)
for step in modulations
)
def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> tuple[Any, ...]:
sig = _cuda_graph_tensor_signature
return (
sig(inputs.trajectory),
_cuda_graph_context_signature(inputs.context),
_cuda_graph_modulation_signature(inputs.modulations),
sig(inputs.action_dim_is_pad),
int(steps),
)
def _clone_static_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None:
return None
static = torch.empty_strided(
tuple(tensor.shape),
tuple(tensor.stride()),
device=tensor.device,
dtype=tensor.dtype,
)
static.copy_(tensor)
return static
def _clone_static_context(context: Any) -> Any:
rope_cache = None
if context.rope_cache is not None:
rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
return context.__class__(
kv_contexts=tuple((_clone_static_tensor(k), _clone_static_tensor(v)) for k, v in context.kv_contexts),
cross_mask=_clone_static_tensor(context.cross_mask),
self_mask=_clone_static_tensor(context.self_mask),
valid_action=_clone_static_tensor(context.valid_action),
rope_cache=rope_cache,
)
def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
return tuple(
step.__class__(
conditioning=_clone_static_tensor(step.conditioning),
block_modulations=tuple(
tuple(_clone_static_tensor(t) for t in block_modulation)
for block_modulation in step.block_modulations
),
final_modulation=tuple(_clone_static_tensor(t) for t in step.final_modulation),
)
for step in modulations
)
def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
return _ActionFlowInputs(
trajectory=_clone_static_tensor(inputs.trajectory),
context=_clone_static_context(inputs.context),
modulations=_clone_static_modulations(inputs.modulations),
action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
)
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
dst.cross_mask.copy_(src.cross_mask)
if src.self_mask is not None:
dst.self_mask.copy_(src.self_mask)
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
dst_tensor.copy_(src_tensor)
def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
dst.trajectory.copy_(src.trajectory)
_copy_context_(dst.context, src.context)
if src.action_dim_is_pad is not None:
dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (_rotate_half(q) * sin)
k_embed = (k * cos) + (_rotate_half(k) * sin)
return q_embed, k_embed
def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _capture_cuda_graph(
fn,
device: torch.device,
*,
after_warmup=None,
) -> tuple[torch.cuda.CUDAGraph, Any]:
warmup_stream = torch.cuda.Stream(device=device)
warmup_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(warmup_stream):
fn()
torch.cuda.current_stream(device).wait_stream(warmup_stream)
if after_warmup is not None:
after_warmup()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
output = fn()
return graph, output
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,431 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.utils import logging
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
IMAGE_TOKENS = [
IMAGE_PATCH_TOKEN,
IM_COL_TOKEN,
IM_START_TOKEN,
LOW_RES_IMAGE_START_TOKEN,
FRAME_START_TOKEN,
IM_END_TOKEN,
FRAME_END_TOKEN,
IMAGE_LOW_RES_TOKEN,
]
class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
"""MolmoAct2 processor kwargs"""
images_kwargs: MolmoAct2ImagesKwargs
videos_kwargs: MolmoAct2VideoProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"videos_kwargs": {"return_metadata": True},
}
class MolmoAct2Processor(ProcessorMixin):
attributes = ["image_processor", "video_processor", "tokenizer"]
optional_attributes = [
"chat_template",
"time_mode",
"image_use_col_tokens",
"use_single_crop_col_tokens",
"use_single_crop_start_token",
"video_use_col_tokens",
"use_frame_special_tokens",
]
image_processor_class = "AutoImageProcessor"
video_processor_class = "AutoVideoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor: MolmoAct2ImageProcessor = None,
video_processor: MolmoAct2VideoProcessor = None,
tokenizer: AutoTokenizer = None,
chat_template: str | None = None,
image_use_col_tokens: bool | None = True,
use_single_crop_col_tokens: bool | None = None,
use_single_crop_start_token: bool | None = True,
video_use_col_tokens: bool | None = False,
use_frame_special_tokens: bool | None = True,
**kwargs,
) -> None:
super().__init__(
image_processor,
video_processor,
tokenizer,
chat_template=chat_template,
)
self.image_use_col_tokens = image_use_col_tokens
self.use_single_crop_col_tokens = use_single_crop_col_tokens
self.use_single_crop_start_token = use_single_crop_start_token
self.video_use_col_tokens = video_use_col_tokens
self.use_frame_special_tokens = use_frame_special_tokens
self.image_placeholder_token = IMAGE_PROMPT
self.video_placeholder_token = VIDEO_PROMPT
self.image_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in IMAGE_TOKENS]
def get_image_tokens(self, image_grid: np.ndarray):
resized_h, resized_w, height, width = image_grid
if int(height) == 0 or int(width) == 0:
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
]
return np.concatenate(joint)
per_row = np.full(width, IMAGE_PATCH_TOKEN)
if self.image_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[IM_START_TOKEN],
np.tile(per_row, [height]),
[IM_END_TOKEN],
]
per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
use_single_crop_col_tokens = (
self.image_use_col_tokens
if self.use_single_crop_col_tokens is None
else self.use_single_crop_col_tokens
)
image_start_token = LOW_RES_IMAGE_START_TOKEN if self.use_single_crop_start_token else IM_START_TOKEN
if use_single_crop_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
joint = [
[image_start_token],
np.tile(per_row, [resized_h]),
[IM_END_TOKEN],
] + joint
return np.concatenate(joint)
def get_video_string(
self,
video_grid: np.ndarray,
timestamps: np.ndarray,
):
if self.use_frame_special_tokens:
start_token_id = FRAME_START_TOKEN
end_token_id = FRAME_END_TOKEN
else:
start_token_id = IM_START_TOKEN
end_token_id = IM_END_TOKEN
num_frames, h, w = video_grid
video_string: str = ""
for frame_idx, frame_time in enumerate(timestamps):
# `per-frame-compact` time mode
prev_space = " " if frame_idx > 0 else ""
frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
video_string += frame_prefix
per_row = np.full(w, IMAGE_PATCH_TOKEN)
if self.video_use_col_tokens:
per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
extra_tokens = np.tile(per_row, [h])
video_tokens = [
[start_token_id],
extra_tokens,
[end_token_id],
]
video_string += "".join(np.concatenate(video_tokens, 0))
return video_string
def insert_bos(
self,
input_ids: np.ndarray,
attention_mask: np.ndarray,
bos_token_id: int,
pad_token_id: int,
):
"""
Args:
input_ids: [B, S] array with left padding
attention_mask: [B, S] array (0 for pad, 1 for valid)
bos_token_id: int
pad_token_id: int
Returns:
input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
attention_mask_out: same shape as input_ids_out
"""
need_to_expand = len(input_ids.shape) == 1
if need_to_expand:
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
# Handle zero-length sequence
if S == 0:
new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
if bos_already_present:
if need_to_expand:
input_ids = input_ids[0]
attention_mask = attention_mask[0]
return input_ids, attention_mask
else:
new_input_ids = np.full((B, S + 1), pad_token_id, dtype=input_ids.dtype)
new_attention_mask = np.zeros((B, S + 1), dtype=attention_mask.dtype)
src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
tgt_idx = src_idx + 1 # shit right
batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
# flatten valid_positions
flat_vals = input_ids[valid_mask]
flat_batch = batch_idx[valid_mask]
flat_tgt = tgt_idx[valid_mask]
new_input_ids[flat_batch, flat_tgt] = flat_vals
new_attention_mask[flat_batch, flat_tgt] = 1
insert_pos = first_valid_index
new_input_ids[np.arange(B), insert_pos] = bos_token_id
new_attention_mask[np.arange(B), insert_pos] = 1
if need_to_expand:
new_input_ids = new_input_ids[0]
new_attention_mask = new_attention_mask[0]
return new_input_ids, new_attention_mask
def __call__(
self,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[MolmoAct2ProcessorKwargs],
) -> BatchFeature:
"""
Args:
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
videos (`dict[str, Any]` or `list[dict[str, Any]]`):
The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
- `"frames"`: `np.ndarray` of shape (T, H, W, 3)
- `"timestamps"`: `np.ndarray` of shape (T,)
- `"sampled_fps"`: `float` (optional)
- `"sampling_augmentation"`: `str` (optional)
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
`BatchFeature`: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
Returned when `images` is not `None`.
- **image_grids** -- Grids of images. Returned when `images` is not `None`.
- **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
Returned when `videos` is not `None`.
- **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
MolmoAct2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
image_grids = image_inputs["image_grids"]
else:
image_inputs = {}
image_grids = None
if videos is not None:
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grids = videos_inputs["video_grids"]
# If user has not requested video metadata, pop it
if "return_metadata" not in kwargs:
video_metadata = videos_inputs.pop("video_metadata")
else:
video_metadata = videos_inputs["video_metadata"]
else:
videos_inputs = {}
video_grids = None
if not isinstance(text, list):
text = [text]
text = text.copy() # below lines change text in-place
if image_grids is not None:
index = 0
for i in range(len(text)):
num_images = text[i].count(self.image_placeholder_token)
image_grids_i = image_grids[index : index + num_images]
for image_grid in image_grids_i:
image_tokens = self.get_image_tokens(image_grid)
image_string = "".join(image_tokens)
text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
index += num_images
if video_grids is not None:
index = 0
for i in range(len(text)):
num_videos = text[i].count(self.video_placeholder_token)
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
)
text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
index += num_videos
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
input_ids = np.array(input_ids)
attention_mask = np.array(attention_mask)
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
input_ids, attention_mask = self.insert_bos(
input_ids, attention_mask, bos, self.tokenizer.pad_token_id
)
if return_mm_token_type_ids:
image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
text_inputs["token_type_ids"] = token_type_ids.tolist()
text_inputs["input_ids"] = input_ids.tolist()
text_inputs["attention_mask"] = attention_mask.tolist()
return BatchFeature(
data={**text_inputs, **image_inputs, **videos_inputs},
tensor_type=return_tensors,
)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode method`.
Returns:
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
MolmoAct2Processor.register_for_auto_class()
@@ -0,0 +1,997 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from contextlib import redirect_stdout
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
logger = logging.get_logger(__name__)
MAX_VIDEO_FPS = 8
def normalize_image(
image: np.ndarray,
image_mean: list[float],
image_std: list[float],
) -> np.ndarray:
if np.allclose(image_mean, [0.5, 0.5, 0.5]) and np.allclose(image_std, [0.5, 0.5, 0.5]):
return image * np.asarray(2.0, dtype=np.float32) - np.asarray(1.0, dtype=np.float32)
image -= np.array(image_mean, dtype=np.float32)[None, None, :]
image /= np.array(image_std, dtype=np.float32)[None, None, :]
return image
def resize_image(
image: np.ndarray,
desired_output_size: list[int],
resample: PILImageResampling,
) -> np.ndarray:
if len(image.shape) == 3:
is_video = False
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
else:
is_video = True
image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2])
dtype = image.dtype
if torch.is_floating_point(image):
in_min = 0.0
in_max = 1.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
)
in_min = 0.0
in_max = 255.0
resized = torchvision.transforms.Resize(
desired_output_size,
resample,
antialias=False,
)(image)
resized = torch.clip(resized, 0, 255).to(dtype)
resized = resized.to(torch.float32)
resized = (resized - in_min) / (in_max - in_min)
if is_video:
resized = torch.permute(resized, [0, 2, 3, 1]).numpy()
else:
resized = torch.permute(resized, [1, 2, 0]).numpy()
return resized
def build_resized_image(
image: np.ndarray,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
) -> tuple[np.ndarray, np.ndarray]:
resized = resize_image(
image,
base_image_input_size,
resample,
)
resized = normalize_image(resized, image_mean, image_std)
if len(resized.shape) == 3:
resized = np.expand_dims(resized, 0)
crop_patch_w = base_image_input_size[1] // image_patch_size
crop_patch_h = base_image_input_size[0] // image_patch_size
resize_idx = np.arange(crop_patch_w * crop_patch_h).reshape([crop_patch_h, crop_patch_w])
return resized, resize_idx
def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray:
"""Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]"""
if len(array.shape) == 3:
n_crops, h, w = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size])
array = np.transpose(array, [0, 1, 3, 2, 4])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size])
return array
else:
n_crops, h, w, c = array.shape
h_patches = h // patch_size
w_patches = w // patch_size
array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c])
array = np.transpose(array, [0, 1, 3, 2, 4, 5])
array = np.reshape(array, [n_crops, h_patches * w_patches, patch_size * patch_size * c])
return array
def arange_for_pooling(
idx_arr: np.ndarray,
pool_h: int,
pool_w: int,
) -> np.ndarray:
h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0]
w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1]
idx_arr = np.pad(
idx_arr,
[[h_pad // 2, (h_pad + 1) // 2], [w_pad // 2, (w_pad + 1) // 2]],
mode="constant",
constant_values=-1,
)
return einops.rearrange(idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w)
def image_to_patches_and_grids(
image: ImageInput,
base_image_input_size: list[int],
resample: PILImageResampling,
image_mean: list[float],
image_std: list[float],
image_patch_size: int,
image_pooling_w: int,
image_pooling_h: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
:return image_grids, the shape of each image after pooling
:return crops, the image crops to processes with the ViT
:return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the
patches in `crops` to pool for that token, masked with -1
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
pooling_w = image_pooling_w
pooling_h = image_pooling_h
resized, resize_idx = build_resized_image(
image,
base_image_input_size,
resample,
image_mean,
image_std,
image_patch_size,
)
pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w)
h, w = pooling_idx.shape[:2]
pooling_idx = pooling_idx.reshape([-1, pooling_h * pooling_w])
image_grid = [h, w]
return (
image_grid,
batch_pixels_to_patches(resized, image_patch_size),
pooling_idx,
)
def get_candidate_target_fps(
video_fps: int | float,
sampling_fps: int | float,
max_fps: int | float = MAX_VIDEO_FPS,
) -> list[float]:
"""
Return the subset of `video_fps` factors that remain multiples of `sampling_fps`.
Examples:
>>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
[2, 6]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
[1, 5]
>>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
[2]
>>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
Traceback (most recent call last):
...
ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps.
"""
video_fps = int(video_fps)
sampling_fps = int(sampling_fps)
max_fps = int(max_fps)
if sampling_fps is None:
raise ValueError("sampling_fps must be provided")
if video_fps <= 0 or sampling_fps <= 0:
raise ValueError(f"video_fps and sampling_fps must be positive (got {video_fps}, {sampling_fps})")
if video_fps % sampling_fps != 0:
raise ValueError(f"sampling_fps={sampling_fps} must divide video_fps={video_fps}.")
candidates = []
for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
if candidate > max_fps:
break
if video_fps % candidate == 0:
candidates.append(float(candidate))
return candidates
def read_video_decord(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the Decord backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import from decord
import importlib
decord = importlib.import_module("decord")
vr = decord.VideoReader(uri=video_path, ctx=decord.cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
time_stamps = vr.get_frame_timestamp(list(range(len(vr))))
duration = time_stamps[-1][1] - time_stamps[0][0]
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="decord",
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
target_timestamps = np.array(target_timestamps)
offset = time_stamps[0, 0]
ix = np.searchsorted(time_stamps[:, 1], target_timestamps + offset, side="right")
ix = np.minimum(ix, len(time_stamps) - 1)
video = vr.get_batch(ix).asnumpy()
metadata.update(
{
"frames_indices": target_timestamps * video_fps,
"height": video.shape[1],
"width": video.shape[2],
}
)
return video, metadata
def read_video_torchcodec(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using torchcodec decoder.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
torchcodec = importlib.import_module("torchcodec")
decoder = torchcodec.decoders.VideoDecoder(
video_path,
# Interestingly `exact` mode takes less than approximate when we load the whole video
seek_mode="exact",
# Allow FFmpeg decide on the number of threads for efficiency
num_ffmpeg_threads=0,
)
# If the first frame starts at > 0, we effectively clip the video starting at that time
# since (most) video players would also skip to that time
time_offset = decoder.metadata.begin_stream_seconds_from_content
# Note this duration does assume we started playing at `time_offset`
duration = decoder.metadata.duration_seconds
metadata = VideoMetadata(
total_num_frames=decoder.metadata.num_frames,
fps=decoder.metadata.average_fps,
duration=duration,
video_backend="torchcodec",
height=decoder.metadata.height,
width=decoder.metadata.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
# Floating point/rounding issues might cause `target_timestamps` to be very slightly
# out-of-bounds, to handle this we sanity check then clip them
assert all(x >= 0 for x in target_timestamps)
assert all(x < duration + 1e-6 for x in target_timestamps)
# 1e-6 padding since torchcodec can throw out-of-bounds errors even if you ask for the
# exact boundary value, we should still get the first/last frame anyway
max_timestamp = decoder.metadata.end_stream_seconds_from_content - 1e-6
min_timestamp = decoder.metadata.begin_stream_seconds_from_content + 1e-6
# Note we avoid using numpy ops here to reduce floating precision issues
timestamps = [x + time_offset for x in target_timestamps]
timestamps = [max(min_timestamp, min(max_timestamp, x)) for x in timestamps]
video = (
decoder.get_frames_played_at(timestamps).data.numpy().transpose(0, 2, 3, 1)
) # Convert to THWC format
target_timestamps = np.array(target_timestamps)
metadata.frames_indices = target_timestamps * metadata.fps
return video, metadata
def read_video_pyav(
video_path,
sample_timestamps_fn: Callable,
**kwargs,
) -> np.ndarray:
"""
Decode a video using the PyAV backend.
Args:
video_path (`str`):
Path to the video file.
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
Returns:
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
# Lazy import torchcodec
import importlib
av = importlib.import_module("av")
with av.open(video_path) as container:
video_stream = container.streams.video[0]
fps = video_stream.average_rate or video_stream.guessed_rate
it = container.decode(video=0)
frames = list(it)
stream = container.streams.video[0]
start = frames[0].pts * stream.time_base
container_end = stream.duration
if container_end is not None:
container_end *= stream.time_base
if container_end is None or container_end < frames[-1].pts:
# Some problem with stream duration, so use the frame PTS directly
# and guess the duration of the last frame
end = frames[-1].pts * stream.time_base + 1 / fps
else:
end = container_end
duration = float(end - start)
metadata = VideoMetadata(
total_num_frames=len(frames),
fps=float(fps),
duration=float(duration),
video_backend="pyav",
height=video_stream.height,
width=video_stream.width,
)
target_timestamps = sample_timestamps_fn(metadata=metadata, **kwargs)
offset = float(start)
target_timestamps = np.array(target_timestamps)
end_time_stamps = np.array([float(frame.pts * stream.time_base) for frame in frames[1:]] + [duration])
indices = np.searchsorted(end_time_stamps, target_timestamps + offset, side="right")
indices = np.minimum(indices, len(end_time_stamps) - 1)
video = np.stack(
[frames[i].to_ndarray(format="rgb24", channel_last=True) for i in indices],
axis=0,
)
metadata.frames_indices = target_timestamps * fps
return video, metadata
VIDEO_DECODERS = {
"decord": read_video_decord,
"torchcodec": read_video_torchcodec,
"pyav": read_video_pyav,
}
def load_video(
video: VideoInput,
backend: str = "decord",
sample_timestamps_fn: Callable | None = None,
**kwargs,
):
"""
Loads `video` to a numpy array.
Args:
video (`VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
backend (`str`, *optional*, defaults to `"decord"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", ""torchcodec"]. Defaults to "decord".
sample_timestamps_fn (`Callable`):
A callable function that will return timestamps at which the video should be sampled.
"""
# Early exit if provided an array or `PIL` frames
if not isinstance(video, str):
metadata = [None] * len(video)
return video, metadata
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
if not is_yt_dlp_available():
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
# Lazy import from yt_dlp
import importlib
yt_dlp = importlib.import_module("yt_dlp")
buffer = BytesIO()
with redirect_stdout(buffer), yt_dlp.YoutubeDL() as f:
f.download([video])
bytes_obj = buffer.getvalue()
file_obj = BytesIO(bytes_obj)
elif video.startswith("http://") or video.startswith("https://"):
file_obj = BytesIO(requests.get(video, timeout=10).content)
elif os.path.isfile(video):
file_obj = video
else:
raise TypeError(
"Incorrect format used for video. Should be an url linking to an video or a local path."
)
# can also load with decord, but not cv2/torchvision
# both will fail in case of url links
video_is_url = video.startswith("http://") or video.startswith("https://")
if video_is_url and backend == "opencv":
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
if (
(not is_decord_available() and backend == "decord")
or (not is_torchcodec_available() and backend == "torchcodec")
or (not is_av_available() and backend == "pyav")
):
raise ImportError(
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
f"Make sure to install {backend} before loading the video."
)
video_decoder = VIDEO_DECODERS[backend]
video, metadata = video_decoder(file_obj, sample_timestamps_fn, **kwargs)
return video, metadata
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: tuple[float],
) -> float:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames, choose the one with higher density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(selected_target_fps, total_frames, max_frames, video_fps):
if selected_target_fps is None:
frame_indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
else:
step_size = max(int(video_fps / selected_target_fps), 1)
frame_indices = np.arange(0, total_frames, step_size)
if len(frame_indices) > max_frames:
frame_indices = frame_indices[:max_frames]
return selected_target_fps, frame_indices
class MolmoAct2VideoProcessorKwargs(VideosKwargs, total=False):
patch_size: int | None
pooling_size: list[int] | None
frame_sample_mode: str | None
max_fps: int | None
sampling_fps: int | None
class MolmoAct2VideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BILINEAR
size = {"height": 378, "width": 378}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
patch_size = 14
pooling_size = [3, 3]
do_sample_frames = True
frame_sample_mode = "uniform_last_frame"
max_fps = 2
sampling_fps = 2
valid_kwargs = MolmoAct2VideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_token_pooling", "video_grids"]
def __init__(self, **kwargs: Unpack[MolmoAct2VideoProcessorKwargs]):
super().__init__(**kwargs)
if self.size is not None and (
self.size.get("height", None) is None or self.size.get("width", None) is None
):
raise ValueError("size must contain 'height' and 'width' keys.")
def _further_process_kwargs(
self,
size: SizeDict | None = None,
**kwargs,
) -> dict:
"""
Update kwargs that need further processing before being validated
Can be overridden by subclasses to customize the processing of kwargs.
"""
if size is not None and ("height" not in size or "width" not in size):
raise ValueError("size must contain 'height' and 'width' keys.")
return super()._further_process_kwargs(size=size, **kwargs)
def sample_times(
self,
metadata: VideoMetadata,
frame_sample_mode: str,
num_frames: int,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Time-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
man_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
duration = metadata.duration or metadata.total_num_frames / metadata.fps
if frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
# Try larger and larger FPSs until we hit one that can't span the video
target_fps = candidate_target_fps[0]
for candidate_fps in candidate_target_fps[1:]:
if num_frames / candidate_fps < duration:
break
target_fps = candidate_fps
times = np.arange(0, num_frames) / target_fps
times = times[times < duration]
return times
elif frame_sample_mode == "uniform_last_frame":
if max_fps is not None:
max_duration = (num_frames - 1) / max_fps # -1 to include the last frame
if max_duration < duration:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
else:
times = np.arange(0.0, stop=duration, step=1 / max_fps)
times = np.concatenate([times, [duration]], axis=0)
assert len(times) <= num_frames
else:
times = np.linspace(0, duration, num=num_frames, endpoint=True, dtype=np.float64)
return times
else:
raise NotImplementedError(frame_sample_mode)
def sample_frames(
self,
metadata: VideoMetadata,
frame_sample_mode: str | None = None,
num_frames: int | None = None,
max_fps: int | None = None,
sampling_fps: int | None = None,
**kwargs,
) -> np.ndarray:
"""
Frame-based sampling if an array video is passed
Args:
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
frame_sample_mode (`str`, *optional*):
Mode to sample frames. Defaults to `self.frame_sample_mode`.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
max_fps (`int`, *optional*):
Maximum frames per second to sample.
sampling_fps (`int`, *optional*):
Sampling frames per second. Defaults to `self.sampling_fps`.
Used when `frame_sample_mode` is `"fps"`.
"""
frame_sample_mode = frame_sample_mode or self.frame_sample_mode
num_frames = num_frames or self.num_frames
sampling_fps = sampling_fps or self.sampling_fps
total_num_frames = metadata.total_num_frames
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
duration = total_num_frames / metadata.fps
if total_num_frames <= 2:
return np.arange(total_num_frames).astype(int)
if duration > (num_frames - 1) / max_fps: # -1 to include the last frame
# uniform fallback
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
else:
float_indices = np.arange(
0.0,
stop=total_num_frames - 1,
step=float(metadata.fps / max_fps),
)
if np.round(float_indices[-1]) != total_num_frames - 1:
float_indices = np.concatenate([float_indices, [total_num_frames - 1]], axis=0)
indices = np.round(float_indices).astype(int)
assert indices[-1] < total_num_frames
assert len(float_indices) <= num_frames
return indices
elif frame_sample_mode == "uniform_last_frame":
indices = np.linspace(
0,
total_num_frames - 1,
num=min(num_frames, total_num_frames),
endpoint=True,
).astype(int)
return indices
elif frame_sample_mode == "fps":
candidate_target_fps = get_candidate_target_fps(metadata.fps, sampling_fps)
selected_target_fps = get_target_fps(
metadata.fps,
num_frames,
total_num_frames,
frame_sample_mode,
candidate_target_fps,
)
_, indices = get_frame_times_and_chosen_fps(
selected_target_fps,
total_num_frames,
num_frames,
metadata.fps,
)
return indices
else:
raise NotImplementedError(frame_sample_mode)
def fetch_videos(self, video_url_or_urls: str | list[str] | list[list[str]], sample_timestamps_fn=None):
"""
Convert a single or a list of urls into the corresponding `np.array` objects.
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
returned.
"""
if (not is_decord_available()) and (not is_torchcodec_available()) and (not is_av_available()):
raise ImportError(
"MolmoAct2VideoProcessor requires `decord`, `torchcodec`, or `av` to be installed."
)
if is_decord_available():
backend = "decord"
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
)
backend = "pyav"
if isinstance(video_url_or_urls, list):
return list(
zip(
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
)
)
else:
return load_video(video_url_or_urls, backend=backend, sample_timestamps_fn=sample_timestamps_fn)
def _decode_and_sample_videos(
self,
videos: VideoInput,
video_metadata: VideoMetadata | dict,
do_sample_frames: bool | None = None,
sample_indices_fn: Callable | None = None,
sample_timestamps_fn: Callable | None = None,
):
"""
Decode input videos and sample frames if needed.
"""
videos = make_batched_videos(videos)
video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
# Framed-based sampling if an array video is passed
# Otherwise, time-based sampling with decoding
if is_valid_video(videos[0]) and do_sample_frames:
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
sampled_metadata.append(metadata)
videos = sampled_videos
video_metadata = sampled_metadata
elif not is_valid_video(videos[0]):
if sample_indices_fn is None:
logger.warning(
"do_sample_frames is False, but video array is not provided: "
"Will decode the video and sample frames using MolmoAct2's default sampling mode"
)
if isinstance(videos[0], list):
raise ValueError("A list of images is not supported for video input!")
else:
videos, video_metadata = self.fetch_videos(videos, sample_timestamps_fn=sample_timestamps_fn)
return videos, video_metadata
def _prepare_input_videos(
self,
videos: VideoInput,
**kwargs,
) -> list[np.ndarray]:
processed_videos = [to_numpy(video) for video in videos]
return processed_videos
def preprocess(
self,
videos: VideoInput,
**kwargs: Unpack[MolmoAct2VideoProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
do_sample_frames = kwargs.pop("do_sample_frames")
video_metadata = kwargs.pop("video_metadata")
sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
sample_timestamps_fn = partial(self.sample_times, **kwargs)
videos, video_metadata = self._decode_and_sample_videos(
videos,
video_metadata=video_metadata,
do_sample_frames=do_sample_frames,
sample_indices_fn=sample_indices_fn,
sample_timestamps_fn=sample_timestamps_fn,
)
videos = self._prepare_input_videos(videos=videos)
kwargs = self._further_process_kwargs(**kwargs)
return_metadata = kwargs.pop("return_metadata")
preprocessed_videos = self._preprocess(videos=videos, **kwargs)
if return_metadata:
preprocessed_videos["video_metadata"] = video_metadata
return preprocessed_videos
def _preprocess(
self,
videos: list[np.ndarray],
size: SizeDict | None = None,
resample: PILImageResampling | None = None,
image_mean: float | list[float] | None = None,
image_std: float | list[float] | None = None,
do_convert_rgb: bool | None = None,
patch_size: int | None = None,
pooling_size: list[int] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess a video for the model.
Args:
videos (`VideoInput`):
Video to preprocess.
size (`SizeDict`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use when resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
patch_size (`int`, *optional*, defaults to `self.patch_size`):
The spatial patch size of the vision encoder.
pooling_size (`list[int]`, *optional*, defaults to `self.pooling_size`):
The pooling size of the vision adapter.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
Returns:
A `BatchFeature` containing the following keys:
- `pixel_values_videos`: The preprocessed videos.
- `video_token_pooling`: The indices of the patches in `crops` to pool for each token in `video_tokens`.
- `video_grids`: The video grids.
"""
if size.height is None or size.width is None:
raise ValueError("size must contain 'height' and 'width' keys.")
base_image_input_size = [size.height, size.width]
resample = resample or self.resample
image_mean = image_mean or self.image_mean
image_std = image_std or self.image_std
do_convert_rgb = do_convert_rgb or self.do_convert_rgb
patch_size = patch_size or self.patch_size
pooling_size = pooling_size or self.pooling_size
image_pooling_h, image_pooling_w = pooling_size
batch_grids = []
batch_crops = []
batch_pooled_patches_idx = []
for video in videos:
all_crops = []
pooled_patches_idx = []
for frame in video:
image_grid, crops, pooled_idx = image_to_patches_and_grids(
frame,
base_image_input_size,
resample,
image_mean,
image_std,
patch_size,
image_pooling_w,
image_pooling_h,
)
offset = sum(np.prod(x.shape[:2]) for x in all_crops)
pooled_idx_with_offset = np.where(pooled_idx >= 0, pooled_idx + offset, pooled_idx)
pooled_patches_idx.append(pooled_idx_with_offset)
all_crops.append(crops)
video_grid = np.array([len(video), image_grid[0], image_grid[1]])
all_crops = np.concatenate(all_crops, 0)
pooled_patches_idx = np.concatenate(pooled_patches_idx, 0)
batch_grids.append(video_grid)
batch_crops.append(all_crops)
batch_pooled_patches_idx.append(pooled_patches_idx)
video_grids = np.stack(batch_grids, 0)
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
return BatchFeature(data, tensor_type=return_tensors)
MolmoAct2VideoProcessor.register_for_auto_class()
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+34 -22
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -258,15 +265,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -274,13 +282,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
+34 -22
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -255,15 +262,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -271,13 +279,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_vla_jepa_README.md
+23
View File
@@ -0,0 +1,23 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]
@@ -0,0 +1,337 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions
@@ -0,0 +1,154 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,629 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model
@@ -0,0 +1,155 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold -1, <= threshold 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,117 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)
@@ -0,0 +1,418 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
return actions
@ProcessorStepRegistry.register("delta_actions_processor")
@ProcessorStepRegistry.register("relative_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
+2
View File
@@ -20,12 +20,14 @@ from .factory import (
make_reward_pre_post_processors as make_reward_pre_post_processors,
)
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
__all__ = [
# Configuration classes
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
"TOPRewardConfig",
# Base class
+16 -2
View File
@@ -25,6 +25,7 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
from .topreward.configuration_topreward import TOPRewardConfig
@@ -38,7 +39,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
Args:
name: The name of the reward model. Supported names are "reward_classifier",
"sarm", "topreward".
"sarm", "robometer", "topreward".
Returns:
The reward model class corresponding to the given name.
@@ -54,6 +55,10 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "robometer":
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
return RobometerRewardModel
elif name == "topreward":
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
@@ -74,7 +79,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
Args:
reward_type: The type of the reward model. Supported types include
"reward_classifier", "sarm", "topreward".
"reward_classifier", "sarm", "robometer", "topreward".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -87,6 +92,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RewardClassifierConfig(**kwargs)
elif reward_type == "sarm":
return SARMConfig(**kwargs)
elif reward_type == "robometer":
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
else:
@@ -168,6 +175,13 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(reward_cfg, RobometerConfig):
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
return make_robometer_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, TOPRewardConfig):
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_robometer import RobometerConfig
from .modeling_robometer import RobometerRewardModel
from .processor_robometer import make_robometer_pre_post_processors
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
@@ -0,0 +1,320 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
For each episode, builds per-frame sub-samples using the frame-steps
strategy from the Robometer eval server: for each original frame ``t``,
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
the Robometer processor + model, and keep the last-frame progress value.
All sub-samples are the same size ``K`` so they batch cleanly.
The parquet uses the same schema as SARM's
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
``progress_sparse``) and the progress-overlay script in
``examples/dataset/create_progress_videos.py`` work without modification.
Usage:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes with batching
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
from lerobot.types import TransitionKey
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
"""Read ``reward_model_path`` from parquet metadata if available."""
if not parquet_path.exists():
return None
try:
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
if metadata and b"reward_model_path" in metadata:
return metadata[b"reward_model_path"].decode()
except Exception: # nosec B110
return None
return None
def _resolve_task(sample: dict[str, Any], default: str) -> str:
"""Best-effort task extraction from a dataset sample."""
task = sample.get("task")
if isinstance(task, str) and task:
return task
return default
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
"""Frame-steps linspace expansion.
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
indices from ``np.linspace(0, t, num_subsampled_frames)`` the first
and last frames are always included. Each entry is a fixed-size array
so the model can batch them.
"""
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
def compute_robometer_progress(
dataset_repo_id: str,
reward_model_path: str,
output_path: str | None = None,
device: str = "cuda",
batch_size: int = 32,
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
episodes: list[int] | None = None,
image_key: str | None = None,
) -> Path:
"""Run Robometer over a dataset and write per-frame progress + success."""
logging.info(f"Loading Robometer: {reward_model_path}")
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
if image_key is not None:
config.image_key = image_key
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
model.to(device).eval()
encoder = RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=num_subsampled_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
)
image_key = config.image_key
logging.info(f"Loading dataset: {dataset_repo_id}")
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
logging.info(f"Processing {len(episode_indices)} episode(s)")
all_index: list[int] = []
all_episode: list[int] = []
all_frame: list[int] = []
all_progress: list[float] = []
for episode_idx in tqdm(episode_indices, desc="Episodes"):
ep = dataset.meta.episodes[episode_idx]
ep_start = int(ep["dataset_from_index"])
ep_end = int(ep["dataset_to_index"])
num_frames = ep_end - ep_start
if num_frames <= 0:
continue
first_sample = dataset[ep_start]
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
end = min(start + batch_size, num_frames)
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
transition = {
TransitionKey.OBSERVATION: {image_key: frames_batch},
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
}
encoded = encoder(transition)
obs = encoded[TransitionKey.OBSERVATION]
batch = {
key: value.to(device) if isinstance(value, torch.Tensor) else value
for key, value in obs.items()
}
with torch.no_grad():
rewards = model.compute_reward(batch)
progress_per_frame[start:end] = rewards.cpu().numpy()
for local in range(num_frames):
all_index.append(ep_start + local)
all_episode.append(episode_idx)
all_frame.append(local)
all_progress.append(float(progress_per_frame[local]))
if device.startswith("cuda"):
torch.cuda.empty_cache()
table = pa.table(
{
"index": np.asarray(all_index, dtype=np.int64),
"episode_index": np.asarray(all_episode, dtype=np.int64),
"frame_index": np.asarray(all_frame, dtype=np.int64),
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
}
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
out.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out)
logging.info(f"Saved {len(table)} frame values to {out}")
progress_arr = np.asarray(all_progress, dtype=np.float32)
if progress_arr.size:
logging.info(
f"Progress: mean={float(progress_arr.mean()):.4f}, "
f"std={float(progress_arr.std()):.4f}, "
f"min={float(progress_arr.min()):.4f}, "
f"max={float(progress_arr.max()):.4f}"
)
return out
def main():
parser = argparse.ArgumentParser(
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Dense per-frame progress for one episode
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--episodes 0
# All episodes, smaller batches for memory-constrained GPUs
python -m lerobot.rewards.robometer.compute_rabc_weights \\
--dataset-repo-id lerobot/libero_10_image \\
--reward-model-path lerobot/Robometer-4B \\
--batch-size 16
""",
)
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
)
parser.add_argument(
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
)
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
parser.add_argument(
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
)
parser.add_argument(
"--num-subsampled-frames",
type=int,
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
)
parser.add_argument(
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
)
parser.add_argument(
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
)
parser.add_argument(
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
reward_model_path = args.reward_model_path
if reward_model_path is None:
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
if reward_model_path:
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
else:
raise ValueError(
"--reward-model-path is required (no existing parquet with model metadata found)."
)
output_path = compute_robometer_progress(
dataset_repo_id=args.dataset_repo_id,
reward_model_path=reward_model_path,
output_path=args.output_path,
device=args.device,
batch_size=args.batch_size,
num_subsampled_frames=args.num_subsampled_frames,
episodes=args.episodes,
image_key=args.image_key,
)
print(f"\nRobometer progress saved to: {output_path}")
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = DEFAULT_OUTPUT_FILENAME
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(output_path),
path_in_repo=hub_path,
repo_id=args.dataset_repo_id,
repo_type="dataset",
)
print(
"Successfully uploaded to: "
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
)
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
print(" rabc_head_mode: sparse")
else:
print("\nTo use in training, add to your config:")
print(" use_rabc: true")
print(f" rabc_progress_path: {output_path}")
print(" rabc_head_mode: sparse")
if __name__ == "__main__":
main()
@@ -0,0 +1,158 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.configs.rewards import RewardModelConfig
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoTokenizer
else:
AutoConfig = None # type: ignore[assignment]
AutoTokenizer = None # type: ignore[assignment]
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
# The order is part of the data contract: upstream resized ``embed_tokens``
# after adding these tokens in this exact order, so changing the set or order
# would silently misalign the saved embedding rows with their token ids.
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
# heads (never read at inference) but still occupy rows the checkpoint expects.
ROBOMETER_SPECIAL_TOKENS = (
"<|split_token|>",
"<|reward_token|>",
"<|pref_token|>",
"<|sim_token|>",
"<|prog_token|>",
)
@RewardModelConfig.register_subclass("robometer")
@dataclass
class RobometerConfig(RewardModelConfig):
"""Configuration for the Robometer reward model."""
pretrained_path: str | None = "lerobot/Robometer-4B"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
reward_output: str = "progress" # "progress" or "success"
success_threshold: float = 0.5
license: str | None = "apache-2.0"
tags: list[str] | None = field(
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
)
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
torch_dtype: str = "bfloat16"
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
average_temporal_patches: bool = True
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
frame_pooling_attn_temperature: float = 1.0
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
progress_discrete_bins: int = 10
# Serialised Qwen backbone config (post-resize). Always populated by
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
# is non-empty after construction. Saved into ``config.json`` automatically
# by the base ``_save_pretrained``.
vlm_config: dict[str, Any] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"REWARD": NormalizationMode.IDENTITY,
}
)
def __post_init__(self) -> None:
super().__post_init__()
if self.reward_output not in {"progress", "success"}:
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
if self.max_frames is not None and self.max_frames < 1:
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
if self.frame_pooling not in {"mean", "boundary", "attention"}:
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
if self.frame_pooling_attn_temperature <= 0:
raise ValueError("frame_pooling_attn_temperature must be > 0")
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
if self.use_per_frame_progress_token and not self.use_multi_image:
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
if self.image_key not in self.input_features:
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
# Deterministically populate ``vlm_config`` so it is non-empty after
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
# vocab the published ``Robometer-4B`` checkpoint was saved with.
if not self.vlm_config:
require_package("transformers", extra="robometer")
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
text_config = vlm.get("text_config")
if not isinstance(text_config, dict):
raise ValueError(
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
"Robometer expects a Qwen-VL-style config."
)
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
self.vlm_config = vlm
@property
def use_discrete_progress(self) -> bool:
"""Whether the progress head outputs distribution logits over bins."""
return self.progress_loss_type.lower() == "discrete"
@property
def vlm_backbone_config(self):
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
require_package("transformers", extra="robometer")
config_dict = deepcopy(self.vlm_config)
model_type = config_dict.pop("model_type", None)
if model_type is None:
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
return AutoConfig.for_model(model_type, **config_dict)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> None:
return None
@property
def reward_delta_indices(self) -> None:
return None
def validate_features(self) -> None:
if self.image_key not in self.input_features:
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
@@ -0,0 +1,481 @@
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
Paper: https://arxiv.org/abs/2603.02115
Project: https://robometer.github.io
Original code: https://github.com/aliang8/robometer
Model: https://huggingface.co/robometer/Robometer-4B
Robometer is a general-purpose, video-language-input reward model built on
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
objective:
- A frame-level progress loss anchoring reward magnitude on expert data.
- A trajectory-comparison preference loss imposing global ordering constraints
across trajectories sharing the same instruction.
To support downstream RL it also predicts a frame-level binary success. The
training prompt inserts three learnable tokens:
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
- ``<|split_token|>`` between two trajectories in preference samples
(training-only).
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
is recovered as the softmax-weighted mean of those centers see
:func:`convert_bins_to_continuous`.
This LeRobot port is **inference-only**: the preference head is preserved in
the state dict for byte-equivalence with the published ``Robometer-4B``
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
success probability depending on :attr:`RobometerConfig.reward_output`.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import torch
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModelForImageTextToText
else:
AutoModelForImageTextToText = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
ROBOMETER_QWEN_INPUT_KEYS = (
"input_ids",
"attention_mask",
"pixel_values",
"pixel_values_videos",
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"mm_token_type_ids",
)
ROBOMETER_METADATA_KEYS = (
"prog_token_id",
"vision_start_token_id",
"vision_end_token_id",
"video_merge_size",
)
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
"""Collapse per-bin logits into a single value in ``[0, 1]``.
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
softmax-weighted mean of those centers.
"""
bin_probs = torch.softmax(bin_logits, dim=-1)
num_bins = bin_logits.shape[-1]
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
return (bin_probs * bin_centers).sum(dim=-1)
def _squeeze_last_safe(x: Tensor) -> Tensor:
"""Drop a trailing singleton dim only when present."""
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
def _torch_dtype(name: str) -> torch.dtype:
dtype = getattr(torch, name, None)
if isinstance(dtype, torch.dtype):
return dtype
raise ValueError(f"Unknown torch dtype: {name!r}")
class RobometerPredictionHead(nn.Sequential):
"""Small MLP head used for Robometer's progress / success / preference outputs."""
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
layers: list[nn.Module] = [
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, output_size),
]
if with_sigmoid:
layers.append(nn.Sigmoid())
super().__init__(*layers)
def decode_progress_outputs(
progress_logits: Tensor | None,
success_logits: Tensor | None,
*,
is_discrete_mode: bool,
) -> dict[str, list[list[float]]]:
"""Decode RBM head outputs into per-frame floats.
Args:
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
is_discrete_mode: if True the progress logits get a softmax over bins
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
Returns:
Dict with ``progress_pred`` and ``success_probs``, each a list of
length ``B`` of per-frame float lists.
"""
progress_pred: list[list[float]] = []
success_probs: list[list[float]] = []
if progress_logits is not None:
for sample_logits in progress_logits:
if is_discrete_mode:
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
progress_pred.append(continuous.flatten().tolist())
else:
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
if success_logits is not None:
for sample_logits in success_logits:
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
return {"progress_pred": progress_pred, "success_probs": success_probs}
class RobometerRewardModel(PreTrainedRewardModel):
"""Robometer (RBM) reward model — inference-only LeRobot port.
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
prediction heads from the paper (progress, success, preference). At
inference time only the progress and success heads are queried; the
preference head is kept on the module so the published ``Robometer-4B``
safetensors load unchanged.
"""
name = "robometer"
config_class = RobometerConfig
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
require_package("transformers", extra="robometer")
super().__init__(config)
self.config = config
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
#
# - Fresh training (``pretrained_path is None``): download the base
# Qwen weights and resize the embed table to match
# ``vlm_config.text_config.vocab_size`` — populated deterministically
# in ``RobometerConfig.__post_init__`` as
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
#
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
# the empty architecture from ``vlm_config`` via
# ``AutoModelForImageTextToText.from_config`` so the subsequent
# ``model.safetensors`` load is a direct fill of the right shape —
# no redundant Qwen weight download.
torch_dtype = _torch_dtype(config.torch_dtype)
if config.pretrained_path is None:
self.model = AutoModelForImageTextToText.from_pretrained(
config.base_model_id,
dtype=torch_dtype,
trust_remote_code=True,
)
target_vocab = config.vlm_config["text_config"]["vocab_size"]
self.model.resize_token_embeddings(target_vocab)
else:
self.model = AutoModelForImageTextToText.from_config(
config.vlm_backbone_config,
dtype=torch_dtype,
trust_remote_code=True,
)
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
# Falls back to the top-level `hidden_size` so future non-multimodal
# variants would still resolve.
backbone_config = self.model.config
text_config = getattr(backbone_config, "text_config", None)
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
if hidden_size is None:
hidden_size = getattr(backbone_config, "hidden_size", None)
if hidden_size is None:
raise AttributeError(
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
)
hidden_dim = int(hidden_size)
# Robometer's three prediction heads + frame-pool attention.
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
self.progress_head = RobometerPredictionHead(
hidden_dim,
progress_output,
dropout=dropout,
with_sigmoid=not config.use_discrete_progress,
)
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
# Match the dtype of the loaded base model so weight loading is a no-op cast.
model_dtype = next(self.model.parameters()).dtype
self.progress_head.to(dtype=model_dtype)
self.preference_head.to(dtype=model_dtype)
self.success_head.to(dtype=model_dtype)
self.frame_pool_attn.to(dtype=model_dtype)
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
inputs = {
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
for key in ROBOMETER_INPUT_KEYS
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
}
if "input_ids" not in inputs:
raise KeyError(
f"Robometer batch missing pre-encoded inputs (expected "
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
"RobometerEncoderProcessorStep ran before `compute_reward`."
)
device = next(self.model.parameters()).device
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
self.eval()
with torch.no_grad():
progress_logits, success_logits = self._compute_rbm_logits(inputs)
decoded = decode_progress_outputs(
progress_logits,
success_logits,
is_discrete_mode=self.config.use_discrete_progress,
)
values = (
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
)
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
if self.config.reward_output == "success":
rewards = (rewards > self.config.success_threshold).float()
else:
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
# progress predictions are clamped to ``[0, 1]`` before being returned.
rewards = rewards.clamp(0.0, 1.0)
return rewards.to(self.config.device or "cpu")
def _compute_rbm_logits(
self,
inputs: dict[str, Any],
) -> tuple[Tensor, Tensor]:
"""Run the Qwen3-VL backbone and apply Robometer's heads.
``inputs`` is the encoded batch produced by
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
as Robometer-specific metadata (``prog_token_id``,
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
the metadata is popped here so the rest can be forwarded straight to
the Qwen model.
Returns ``(progress_logits, success_logits)``. Shapes:
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
"""
prog_token_id = inputs.pop("prog_token_id", None)
vision_start_token_id = inputs.pop("vision_start_token_id", None)
vision_end_token_id = inputs.pop("vision_end_token_id", None)
video_merge_size = inputs.pop("video_merge_size", 14)
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
# full hidden-state tuple and take the last layer. This matches the
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
hidden_state = (
outputs.hidden_states[-1]
if getattr(outputs, "hidden_states", None)
else outputs.last_hidden_state
)
input_ids = inputs["input_ids"]
if self.config.use_per_frame_progress_token:
if prog_token_id is None:
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
if self.config.use_multi_image:
if vision_start_token_id is None or vision_end_token_id is None:
raise KeyError(
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
"(run RobometerEncoderProcessorStep first)"
)
return self._process_multi_image_frames(
hidden_state,
input_ids,
start_id=vision_start_token_id,
end_id=vision_end_token_id,
)
video_grid_thw = inputs.get("video_grid_thw")
if video_grid_thw is None:
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
if vision_start_token_id is None:
raise KeyError("`vision_start_token_id` missing in batch")
return self._process_video_frames(
hidden_state,
input_ids,
video_grid_thw,
start_id=vision_start_token_id,
merge_size=video_merge_size,
)
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
"""Apply progress + success heads to a tensor of frame embeddings."""
progress_out = self.progress_head(frame_embeddings)
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
success = _squeeze_last_safe(self.success_head(frame_embeddings))
return progress, success
def _process_token_extraction(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
prog_token_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
token_mask = input_ids == prog_token_id
batch_indices, positions = token_mask.nonzero(as_tuple=True)
if positions.numel() == 0:
raise ValueError("`<|prog_token|>` not found in any sequence")
per_sample_hidden = [
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
]
progress_list, success_list = [], []
for embeddings in per_sample_hidden:
if embeddings.shape[0] == 0:
raise ValueError("`<|prog_token|>` missing in a sequence")
progress, success = self._apply_heads_to_hidden_states(embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _process_multi_image_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
*,
start_id: int,
end_id: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
frame_embeddings = self._extract_hidden_states_from_token_pairs(
seq_hidden, seq_ids, start_id, end_id
)
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
def _extract_hidden_states_from_token_pairs(
self,
hidden_state: Tensor,
input_ids: Tensor,
start_id: int,
end_id: int,
) -> Tensor:
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
if start_positions.numel() != end_positions.numel():
raise ValueError(
f"Mismatched vision token counts: {start_positions.numel()} start vs "
f"{end_positions.numel()} end"
)
frames: list[Tensor] = []
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
if start >= end:
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
patch_tokens = hidden_state[start + 1 : end]
if patch_tokens.shape[0] == 0:
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
continue
pooling = self.config.frame_pooling
if pooling == "mean":
frames.append(patch_tokens.mean(dim=0))
elif pooling == "boundary":
frames.append(patch_tokens[-1])
else: # attention
scores = (
self.frame_pool_attn(patch_tokens).squeeze(-1)
/ self.config.frame_pooling_attn_temperature
)
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
frames.append((weights * patch_tokens).sum(dim=0))
return torch.stack(frames)
def _process_video_frames(
self,
hidden_state: Tensor,
input_ids: Tensor,
video_grid_thw: Tensor,
*,
start_id: int,
merge_size: int,
) -> tuple[Tensor, Tensor]:
"""Per-frame progress/success in video mode (Qwen-VL)."""
progress_list, success_list = [], []
for batch_idx in range(input_ids.shape[0]):
seq_ids = input_ids[batch_idx]
seq_hidden = hidden_state[batch_idx]
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
if start_positions.numel() == 0:
raise ValueError("`<|vision_start|>` not found in sequence")
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
cursor = start_positions[0].item()
frame_embeddings: list[Tensor] = []
for _ in range(t_dim):
if self.config.average_temporal_patches:
patch = seq_hidden[cursor : cursor + tokens_per_frame]
frame_embeddings.append(patch.mean(dim=0))
else:
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
cursor += tokens_per_frame
stacked = torch.stack(frame_embeddings)
progress, success = self._apply_heads_to_hidden_states(stacked)
progress_list.append(progress)
success_list.append(success)
return torch.stack(progress_list), torch.stack(success_list)
@@ -0,0 +1,338 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Robometer pre/post processing pipelines."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
policy_action_to_transition,
)
from lerobot.rewards.robometer.configuration_robometer import (
ROBOMETER_SPECIAL_TOKENS,
RobometerConfig,
)
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor
else:
AutoProcessor = None
PROGRESS_PROMPT = (
"The task for the robot is '{task}'. Given the trajectory video, predict "
"the task progress at each frame, how far along the robot is towards "
"completing the task, a float between 0 and 1, where 0 is the starting "
"state and 1 is when the task is completed. If the robot is not "
"performing the same task, predict 0 progress."
)
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
if frames.ndim != 4:
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
if frames.dtype != np.uint8:
frames = np.clip(frames, 0, 255).astype(np.uint8)
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
if max_frames is not None:
video = video[-max_frames:]
if video.shape[1] in (1, 3):
video = video.permute(0, 2, 3, 1)
elif video.shape[-1] not in (1, 3):
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
array = video.detach().cpu().numpy()
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
array = array * 255.0
return np.clip(array, 0, 255).astype(np.uint8)
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
if task is None:
task = default
if task is None:
raise KeyError("Robometer expected a task description in complementary data")
if isinstance(task, str):
return [task] * batch_size
if isinstance(task, tuple):
task = list(task)
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
if len(task) == 1 and batch_size > 1:
return task * batch_size
if len(task) != batch_size:
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
return task
@dataclass
@ProcessorStepRegistry.register(name="robometer_encoder")
class RobometerEncoderProcessorStep(ProcessorStep):
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
registers Robometer's special tokens on the tokenizer. The matching
embedding resize happens model-side in
:meth:`RobometerRewardModel.__init__`.
At call time the step reads:
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
- ``complementary_data[task_key]``: a string or list of strings.
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
- Robometer-specific token ids consumed by the model heads:
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
``video_merge_size``.
"""
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
image_key: str = OBS_IMAGES + ".top"
task_key: str = "task"
default_task: str | None = None
max_frames: int | None = 8
use_multi_image: bool = True
use_per_frame_progress_token: bool = True
max_length: int = 1024
_processor: Any = field(default=None, init=False, repr=False)
def __post_init__(self) -> None:
require_package("transformers", extra="robometer")
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
self._processor = AutoProcessor.from_pretrained(
self.base_model_id,
trust_remote_code=True,
do_sample_frames=False,
padding_side="right",
)
# Register Robometer's special tokens on the tokenizer. The matching
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
tokenizer = self._processor.tokenizer
# Qwen tokenizers may not define a pad token, but batched prompts/videos
# require padding, so reuse EOS as the padding token.
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
for token in ROBOMETER_SPECIAL_TOKENS:
if token not in tokenizer.get_vocab():
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if not isinstance(observation, dict):
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
if self.image_key not in observation:
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
frames = observation[self.image_key]
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
if tensor.ndim == 4:
tensor = tensor.unsqueeze(1)
elif tensor.ndim != 5:
raise ValueError(
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
)
batch_size = tensor.shape[0]
tasks = _expand_tasks(
complementary.get(self.task_key, self.default_task),
batch_size=batch_size,
default=self.default_task,
)
samples = [
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
]
encoded = self.encode_samples(samples)
new_observation = dict(observation)
for key, value in encoded.items():
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
from qwen_vl_utils import process_vision_info
conversations = [self._build_conversation(frames, task) for frames, task in samples]
texts = [
self._processor.apply_chat_template(
msg,
tokenize=False,
add_generation_prompt=False,
add_vision_id=True,
enable_thinking=False,
fps=1,
)
for msg in conversations
]
process_kwargs: dict[str, Any] = {
"return_video_kwargs": True,
"return_video_metadata": True,
}
image_processor = getattr(self._processor, "image_processor", None)
if image_processor is not None and hasattr(image_processor, "patch_size"):
process_kwargs["image_patch_size"] = image_processor.patch_size
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
videos: list[Any] | None = None
video_metadatas: list[Any] | None = None
if video_inputs:
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
videos = list(videos_seq)
video_metadatas = list(metadatas_seq)
else:
videos = list(video_inputs)
processor_kwargs: dict[str, Any] = {
"text": texts,
"images": image_inputs,
"padding": True,
"truncation": False,
"max_length": self.max_length,
"return_tensors": "pt",
"do_resize": False,
}
if videos is not None:
processor_kwargs["videos"] = videos
if video_metadatas is not None:
processor_kwargs["video_metadata"] = video_metadatas
if video_kwargs:
processor_kwargs.update(video_kwargs)
encoded = self._processor(**processor_kwargs)
# Write Robometer-specific token ids and the video patch merge size into
# the encoded batch so `RobometerRewardModel` doesn't need its own
# tokenizer at inference (EO1-style separation: the processor owns the
# tokenizer, the model owns the backbone and heads).
tokenizer = self._processor.tokenizer
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
video_processor = getattr(self._processor, "video_processor", None)
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
return encoded
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
pil_frames = _frames_to_pil(frames)
prompt = PROGRESS_PROMPT.format(task=task)
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
if self.use_multi_image:
for image in pil_frames:
content.append({"type": "image", "image": image})
if self.use_per_frame_progress_token:
content.append({"type": "text", "text": "<|prog_token|>"})
else:
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
return [{"role": "user", "content": content}]
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"base_model_id": self.base_model_id,
"image_key": self.image_key,
"task_key": self.task_key,
"default_task": self.default_task,
"max_frames": self.max_frames,
"use_multi_image": self.use_multi_image,
"use_per_frame_progress_token": self.use_per_frame_progress_token,
"max_length": self.max_length,
}
def make_robometer_pre_post_processors(
config: RobometerConfig,
dataset_stats: dict[str, dict[str, Any]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
The preprocessor adds a batch dimension if needed, runs Robometer's
encoder, and moves everything to the configured device. The
postprocessor is the identity since Robometer outputs a single reward
tensor.
"""
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
AddBatchDimensionProcessorStep(),
RobometerEncoderProcessorStep(
base_model_id=config.base_model_id,
image_key=config.image_key,
task_key=config.task_key,
default_task=config.default_task,
max_frames=config.max_frames,
use_multi_image=config.use_multi_image,
use_per_frame_progress_token=config.use_per_frame_progress_token,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
+4
View File
@@ -23,6 +23,7 @@ from .configs import (
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
@@ -49,6 +50,7 @@ from .inference import (
from .strategies import (
BaseStrategy,
DAggerStrategy,
EpisodicStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
@@ -66,6 +68,8 @@ __all__ = [
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"EpisodicStrategy",
"EpisodicStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
+36 -1
View File
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("episodic")
@dataclass
class EpisodicStrategyConfig(RolloutStrategyConfig):
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
Keyboard controls:
Right arrow end current episode or reset phase early
Left arrow discard current episode and re-record
Escape stop recording session
In between episodes:
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
- else, the robot is moved smoothly to the position of the teleop leader.
"""
# This only applies if there are no teleop leaders specified.
# When True (default), moves the robot back to the joint positions captured at startup.
# Otherwise, leave the robot in its current position.
reset_to_initial_position: bool = True
# Whether to turn on or off the leader -> follower smooth handover behavior.
# When False, fallback to follower -> leader handover.
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
smooth_leader_to_follower_handover: bool = True
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
@@ -229,7 +258,13 @@ class RolloutConfig:
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
self.strategy,
(
SentryStrategyConfig,
HighlightStrategyConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
),
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
@@ -17,6 +17,7 @@
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .episodic import EpisodicStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -27,6 +28,7 @@ __all__ = [
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"EpisodicStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
+14 -69
View File
@@ -56,10 +56,14 @@ from typing import Any
import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.common.control_utils import (
follower_smooth_move_to,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
@@ -171,64 +174,6 @@ class DAggerEvents:
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def _follower_smooth_move_to(
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm
to the follower, we bring the follower to the teleop's current pose.
Both ``current`` and ``target`` must be in robot-action key space.
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Pausing engine - robot holds position")
engine.pause()
if _teleop_supports_feedback(teleop) and prev_action is not None:
if teleop_supports_feedback(teleop) and prev_action is not None:
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
logger.info("Smooth handover: moving leader arm to follower position")
_teleop_smooth_move_to(teleop, prev_action)
teleop_smooth_move_to(teleop, prev_action)
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode - human teleop control")
if not _teleop_supports_feedback(teleop) and prev_action is not None:
if not teleop_supports_feedback(teleop) and prev_action is not None:
logger.info("Smooth handover: sliding follower to teleop position")
obs = robot.get_observation()
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
_follower_smooth_move_to(robot, prev_action, target)
follower_smooth_move_to(robot, prev_action, target)
# unlock the teleop for human control
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.enable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
engine.resume()
# release teleop before resuming the policy
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
# ------------------------------------------------------------------
+335
View File
@@ -0,0 +1,335 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
- Policy drives the robot during each recording episode.
- An optional teleoperator can drive the robot during reset phases so the
operator can bring the environment back to its starting configuration.
If no teleop is connected the robot stays in its current position.
- Keyboard controls:
Right arrow end the current episode or reset phase early
Left arrow discard the current episode and re-record it
Escape stop the recording session
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
"""
from __future__ import annotations
import contextlib
import logging
import time
from lerobot.common.control_utils import (
follower_smooth_move_to,
init_keyboard_listener,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_rerun_data
from ..configs import EpisodicStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
class EpisodicStrategy(RolloutStrategy):
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
follows every episode (except the last) so the operator can manually
reset the environment. During the reset phase, an optional teleoperator
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
The policy state (hidden state, RTC queue, interpolator) is reset at
the start of each recording episode.
Keyboard events:
right arrow end current episode or reset phase early
left arrow discard & re-record current episode
ESC stop the session
"""
config: EpisodicStrategyConfig
def __init__(self, config: EpisodicStrategyConfig) -> None:
super().__init__(config)
self._listener = None
self._events: dict | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Start the inference engine and attach the keyboard listener."""
self._init_engine(ctx)
self._listener, self._events = init_keyboard_listener()
logger.info("Episodic strategy ready")
def run(self, ctx: RolloutContext) -> None:
"""Main multi-episode recording loop."""
cfg = ctx.runtime.cfg
dataset_cfg = cfg.dataset
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
features = ctx.data.dataset_features
fps = cfg.fps
episode_time_s = dataset_cfg.episode_time_s
reset_time_s = dataset_cfg.reset_time_s
num_episodes = dataset_cfg.num_episodes
single_task = dataset_cfg.single_task or cfg.task
play_sounds = cfg.play_sounds
display_compressed = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
else cfg.display_compressed_images
)
with VideoEncodingManager(dataset):
try:
recorded_episodes = 0
while recorded_episodes < num_episodes and not events["stop_recording"]:
if ctx.runtime.shutdown_event.is_set():
break
# Reset policy state at episode start (discard leftover hidden state / queue)
self._engine.reset()
self._interpolator.reset()
self._engine.resume()
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
self._policy_loop(
ctx=ctx,
robot=robot,
events=events,
features=features,
fps=fps,
control_time_s=episode_time_s,
dataset=dataset,
single_task=single_task,
)
# Reset phase, skip after the last episode (but run when re-recording)
if not events["stop_recording"] and (
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
if teleop:
# Smooth handover so the transition to teleop control is jerk-free.
# For actuated teleops: drive the leader arm to the follower's current
# position so the operator takes over without fighting the arm.
# For non-actuated teleops: slide the follower to the teleop's current
# pose instead, since the leader cannot be driven.
obs = robot.get_observation()
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
if (
teleop_supports_feedback(teleop)
and self.config.smooth_leader_to_follower_handover
):
logger.info("Smooth handover: moving leader arm to follower position")
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
teleop.disable_torque()
else:
logger.info("Smooth handover: sliding follower to teleop position")
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
elif self.config.reset_to_initial_position:
# No teleop: return the robot to its startup position.
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
self._reset_loop(
ctx=ctx,
robot=robot,
teleop=teleop,
events=events,
fps=fps,
control_time_s=reset_time_s,
display_data=cfg.display_data,
display_compressed=display_compressed,
)
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
# returns to its initial joint positions captured at startup
if not teleop and self.config.reset_to_initial_position:
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
continue
dataset.save_episode()
recorded_episodes += 1
finally:
# Save any frames buffered in the current episode so an unexpected
# exception or KeyboardInterrupt does not silently drop recorded data.
# suppress: save_episode raises if the buffer is empty (nothing to lose).
logger.info("Episodic control loop ended — saving any in-progress episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def _policy_loop(
self,
ctx: RolloutContext,
robot,
events: dict,
features: dict,
fps: float,
control_time_s: float,
dataset,
single_task: str,
) -> None:
"""Policy-driven recording loop for a single episode."""
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(fps)
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
if sleep_t < 0:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
"Dataset frames might be dropped and robot control might be unstable. "
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
"3) CPU starvation"
)
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def _reset_loop(
self,
ctx: RolloutContext,
robot,
teleop,
events: dict,
fps: float,
control_time_s: float,
display_data: bool,
display_compressed: bool,
) -> None:
"""Reset-phase loop: teleop drives the robot if available, no recording."""
processors = ctx.processors
control_interval = 1.0 / fps
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
if teleop is not None:
act = teleop.get_action()
act_teleop = processors.teleop_action_processor((act, obs))
robot_action = processors.robot_action_processor((act_teleop, obs))
robot.send_action(robot_action)
if display_data:
obs_processed = processors.robot_observation_processor(obs)
log_rerun_data(
observation=obs_processed,
action=act_teleop,
compress_images=display_compressed,
)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def teardown(self, ctx: RolloutContext) -> None:
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
cfg = ctx.runtime.cfg
play_sounds = cfg.play_sounds
log_say("Stop recording", play_sounds, blocking=True)
if not is_headless() and self._listener is not None:
self._listener.stop()
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if (
cfg.dataset is not None
and cfg.dataset.push_to_hub
and ctx.data.dataset is not None
and safe_push_to_hub(
ctx.data.dataset,
tags=cfg.dataset.tags,
private=cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(
ctx.hardware,
return_to_initial_position=cfg.return_to_initial_position,
)
log_say("Exiting", play_sounds)
logger.info("Episodic strategy teardown complete")
+6 -1
View File
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .episodic import EpisodicStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
return HighlightStrategy(config)
if config.type == "dagger":
return DAggerStrategy(config)
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
if config.type == "episodic":
return EpisodicStrategy(config)
raise ValueError(
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
)
+90 -11
View File
@@ -105,6 +105,7 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
predicted_latents_callback: Callable[[PreTrainedPolicy], None] | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -134,6 +135,9 @@ def rollout(
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
predicted_latents_callback: Optional callback invoked after every ``select_action`` with the policy
itself. World-model policies (e.g. LingBot-VA) stash predicted video latents on
``policy.last_predicted_latents``; this lets the caller concatenate chunks and decode once.
Returns:
The dictionary described above.
"""
@@ -184,6 +188,8 @@ def rollout(
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
if predicted_latents_callback is not None:
predicted_latents_callback(policy)
action = postprocessor(action)
action_transition = {ACTION: action}
@@ -203,12 +209,22 @@ def rollout(
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
if isinstance(final_info, dict):
is_success = final_info.get("is_success", [False] * env.num_envs)
successes = (
is_success.tolist()
if hasattr(is_success, "tolist")
else [bool(is_success)] * env.num_envs
)
successes = final_info["is_success"].tolist()
else:
# Gymnasium < 1.0 returns final_info as a per-env sequence/object array,
# with entries set to a dict only for envs that just finished.
successes = []
for item in final_info:
if isinstance(item, dict) and "is_success" in item:
successes.append(bool(item["is_success"]))
else:
successes.append(False)
elif "is_success" in info:
is_success = info["is_success"]
successes = (
@@ -273,6 +289,7 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
save_predicted_video: bool = False,
) -> dict:
"""
Args:
@@ -291,6 +308,11 @@ def eval_policy(
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
save_predicted_video = save_predicted_video or bool(
getattr(getattr(policy, "config", None), "save_predicted_video", False)
)
if not isinstance(policy, PreTrainedPolicy):
exc = ValueError(
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
@@ -334,6 +356,22 @@ def eval_policy(
if max_episodes_rendered > 0:
video_paths: list[str] = []
if save_predicted_video:
if not videos_dir:
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
predicted_video_paths: list[str] = []
n_predicted_rendered = 0
# Collect predicted-video latents across a rollout (world-model policies only). The latents are
# concatenated and decoded once after the rollout, matching upstream LingBot-VA's visualization path.
def collect_predicted_latents(policy: PreTrainedPolicy):
latents = getattr(policy, "last_predicted_latents", None)
if latents is not None:
pred_latents.append(
latents.detach().to("cpu") if hasattr(latents, "detach") else torch.as_tensor(latents).cpu()
)
policy.last_predicted_latents = None
if return_episode_data:
episode_data: dict | None = None
@@ -345,6 +383,9 @@ def eval_policy(
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
if save_predicted_video:
pred_latents: list[torch.Tensor] = []
if start_seed is None:
seeds = None
else:
@@ -361,6 +402,7 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
predicted_latents_callback=collect_predicted_latents if save_predicted_video else None,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -426,6 +468,35 @@ def eval_policy(
threads.append(thread)
n_episodes_rendered += 1
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
if save_predicted_video and len(pred_latents) > 0:
predicted_latent = torch.cat(pred_latents, dim=2)
decoder = getattr(policy, "decode_predicted_latents", None) or getattr(
policy, "_decode_predicted_video", None
)
if decoder is None:
raise AttributeError(
"Policy config requested predicted-video saving, but the policy does not expose "
"`decode_predicted_latents` or `_decode_predicted_video`."
)
predicted_video = decoder(predicted_latent)
if hasattr(predicted_video, "detach"):
predicted_video = predicted_video.detach().to("cpu").numpy()
videos_dir.mkdir(parents=True, exist_ok=True)
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
predicted_video_paths.append(str(predicted_video_path))
thread = threading.Thread(
target=write_video,
args=(
str(predicted_video_path),
predicted_video,
env.unwrapped.metadata["render_fps"],
),
)
thread.start()
threads.append(thread)
n_predicted_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
)
@@ -469,6 +540,9 @@ def eval_policy(
if max_episodes_rendered > 0:
info["video_paths"] = video_paths
if save_predicted_video:
info["predicted_video_paths"] = predicted_video_paths
return info
@@ -600,9 +674,10 @@ class TaskMetrics(TypedDict):
max_rewards: list[float]
successes: list[bool]
video_paths: list[str]
predicted_video_paths: list[str]
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths", "predicted_video_paths")
def eval_one(
@@ -643,6 +718,7 @@ def eval_one(
max_rewards=[ep["max_reward"] for ep in per_episode],
successes=[ep["success"] for ep in per_episode],
video_paths=task_result.get("video_paths", []),
predicted_video_paths=task_result.get("predicted_video_paths", []),
)
@@ -689,6 +765,7 @@ def run_one(
# ensure we always provide video_paths key to simplify accumulation
if max_episodes_rendered > 0:
metrics.setdefault("video_paths", [])
metrics.setdefault("predicted_video_paths", [])
return task_group, task_id, metrics
@@ -742,11 +819,11 @@ def eval_policy_all(
_append("sum_rewards", metrics.get("sum_rewards"))
_append("max_rewards", metrics.get("max_rewards"))
_append("successes", metrics.get("successes"))
# video_paths is list-like
paths = metrics.get("video_paths", [])
if paths:
group_acc[group]["video_paths"].extend(paths)
overall["video_paths"].extend(paths)
for key in ("video_paths", "predicted_video_paths"):
paths = metrics.get(key, [])
if paths:
group_acc[group][key].extend(paths)
overall[key].extend(paths)
# Choose runner (sequential vs threaded)
task_runner = partial(
@@ -814,6 +891,7 @@ def eval_policy_all(
"pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
"n_episodes": len(acc["sum_rewards"]),
"video_paths": list(acc["video_paths"]),
"predicted_video_paths": list(acc["predicted_video_paths"]),
}
# overall aggregates
@@ -825,6 +903,7 @@ def eval_policy_all(
"eval_s": time.time() - start_t,
"eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
"video_paths": list(overall["video_paths"]),
"predicted_video_paths": list(overall["predicted_video_paths"]),
}
return {
+13
View File
@@ -25,6 +25,7 @@ Strategies
--strategy.type=sentry Continuous recording with auto-upload
--strategy.type=highlight Ring buffer + keystroke save
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
--strategy.type=episodic Episode-oriented recording with reset phases
Inference backends
------------------
@@ -111,6 +112,18 @@ Usage examples
--display_data=true \\
--use_torch_compile=true
# Episodic mode — episode-oriented recording with reset phases
lerobot-rollout \\
--strategy.type=episodic \\
--policy.path=user/my_policy \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--teleop.type=so100_leader \\
--teleop.port=/dev/ttyACM1 \\
--dataset.repo_id=user/rollout_episodic_data \\
--dataset.num_episodes=20 \\
--dataset.single_task="Grab the cube"
# Resume a previous sentry recording session
lerobot-rollout \\
--strategy.type=sentry \\
+12 -17
View File
@@ -292,19 +292,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
if (
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_relative_actions=true with pretrained processors can skip relative transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
processor_kwargs = {}
postprocessor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -312,24 +301,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
processor_kwargs["dataset_meta"] = dataset.meta
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
preprocessor_overrides = {
"device_processor": {"device": device.type},
"normalizer_processor": {
"stats": dataset.meta.stats,
"features": {**policy.config.input_features, **policy.config.output_features},
"norm_map": policy.config.normalization_mapping,
},
"rename_observations_processor": {"rename_map": cfg.rename_map},
}
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
"rename_map": cfg.rename_map
}
postprocessor_kwargs["postprocessor_overrides"] = {
postprocessor_overrides = {
"unnormalizer_processor": {
"stats": dataset.meta.stats,
"features": policy.config.output_features,
"norm_map": policy.config.normalization_mapping,
},
}
if getattr(active_cfg, "use_relative_actions", False):
preprocessor_overrides["relative_actions_processor"] = {
"enabled": True,
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
"action_names": getattr(active_cfg, "action_feature_names", None),
}
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
if cfg.is_reward_model_training:
preprocessor, postprocessor = make_reward_pre_post_processors(
@@ -341,7 +337,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
@@ -13,6 +13,8 @@
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
{% elif model_name == "sarm" %}
A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
{% elif model_name == "robometer" %}
ROBOMETER is a general-purpose video-language robotic reward model built on a fine-tuned Qwen3-VL-4B backbone with progress, preference, and success heads. Given a trajectory video and a task description, it predicts dense, frame-level task progress in [0, 1] and frame-level success probabilities for downstream robot learning, including offline RL, online RL, data filtering and retrieval, and automated failure detection.
{% elif model_name == "topreward" %}
TOPReward is a **zero-shot** reward model that extracts token log-probabilities from an off-the-shelf vision-language model (default Qwen3-VL) as a reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood of the instruction being true, with no fine-tuning required.
{% else %}
+36
View File
@@ -24,6 +24,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import datasets
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@pytest.mark.parametrize(
"dtype,np_dtype,values,assert_fn",
[
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
],
ids=["float32", "int64", "bool"],
)
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
):
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
captured = {}
original_from_dict = datasets.Dataset.from_dict
def _from_dict_spy(cls, mapping, *args, **kwargs):
captured["state"] = mapping["state"]
return original_from_dict(mapping, *args, **kwargs)
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
dataset.save_episode()
dataset.finalize()
assert "state" in captured
assert isinstance(captured["state"], np.ndarray)
assert captured["state"].shape == (2,)
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
def test_set_image_transforms_applies_transparently(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
+140
View File
@@ -0,0 +1,140 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``lerobot.datasets.video_utils.VideoDecoderCache``.
These cover the LRU bounding + file-handle release behaviour added to prevent
unbounded growth when iterating over datasets with many distinct video files
(observed: ~35 GB anon-rss per DataLoader worker on an 8 k-file dataset).
"""
import shutil
from pathlib import Path
import pytest
pytest.importorskip("torchcodec", reason="torchcodec is required (install lerobot[dataset])")
from lerobot.datasets.video_utils import VideoDecoderCache # noqa: E402
TEST_ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts" / "encoded_videos"
SRC_CLIP = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
def _make_distinct_clips(tmp_path: Path, n: int) -> list[Path]:
"""Copy the small reference mp4 to ``n`` distinct paths.
The cache keys on absolute path, so distinct paths force distinct cache entries
even though the file contents are identical.
"""
assert SRC_CLIP.exists(), f"missing test artifact {SRC_CLIP}"
paths = []
for i in range(n):
dst = tmp_path / f"clip_{i:04d}.mp4"
shutil.copyfile(SRC_CLIP, dst)
paths.append(dst)
return paths
class TestVideoDecoderCacheBounded:
def test_default_cache_is_bounded(self):
"""The default cache must have a finite ``max_size`` to bound RSS growth."""
cache = VideoDecoderCache()
assert cache.max_size is not None, "default cache must be bounded"
assert cache.max_size > 0
def test_size_capped_at_max_size(self, tmp_path):
"""``get_decoder`` for >``max_size`` distinct paths must NOT grow without bound."""
paths = _make_distinct_clips(tmp_path, n=5)
cache = VideoDecoderCache(max_size=2)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 2
def test_evicts_least_recently_used(self, tmp_path):
"""Re-accessing an entry must promote it; the LRU entry is the one evicted."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=2)
cache.get_decoder(paths[0])
cache.get_decoder(paths[1])
cache.get_decoder(paths[0]) # promote paths[0] to MRU; paths[1] is now LRU
cache.get_decoder(paths[2]) # should evict paths[1]
assert str(paths[0]) in cache # MRU stays
assert str(paths[1]) not in cache # LRU evicted
assert str(paths[2]) in cache # newest stays
def test_eviction_closes_file_handle(self, tmp_path):
"""Evicting an entry must close its fsspec file handle (otherwise we leak FDs)."""
paths = _make_distinct_clips(tmp_path, n=2)
cache = VideoDecoderCache(max_size=1)
cache.get_decoder(paths[0])
# Reach into the cache to capture the handle before it is evicted. This is
# the only assertion in the suite that touches a private attribute, and it
# is the most direct way to prove the file descriptor is actually released.
evicted_handle = cache._cache[str(paths[0])][1]
assert evicted_handle.closed is False
cache.get_decoder(paths[1]) # forces eviction of paths[0]
assert evicted_handle.closed is True
def test_clear_closes_all_file_handles(self, tmp_path):
"""``clear()`` must close every cached file handle."""
paths = _make_distinct_clips(tmp_path, n=3)
cache = VideoDecoderCache(max_size=10)
for p in paths:
cache.get_decoder(p)
handles = [entry[1] for entry in cache._cache.values()]
assert all(not h.closed for h in handles)
cache.clear()
assert cache.size() == 0
assert all(h.closed for h in handles)
def test_hit_does_not_reopen_or_evict(self, tmp_path):
"""A cache hit must return the same decoder instance without touching the cap."""
paths = _make_distinct_clips(tmp_path, n=1)
cache = VideoDecoderCache(max_size=2)
first = cache.get_decoder(paths[0])
second = cache.get_decoder(paths[0])
assert first is second
assert cache.size() == 1
def test_unbounded_when_max_size_none(self, tmp_path):
"""``max_size=None`` preserves the legacy unbounded behaviour."""
paths = _make_distinct_clips(tmp_path, n=4)
cache = VideoDecoderCache(max_size=None)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 4
def test_env_var_overrides_default(self, tmp_path, monkeypatch):
"""``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var sets the default ``max_size``."""
monkeypatch.setenv("LEROBOT_VIDEO_DECODER_CACHE_SIZE", "3")
cache = VideoDecoderCache()
assert cache.max_size == 3
paths = _make_distinct_clips(tmp_path, n=5)
for p in paths:
cache.get_decoder(p)
assert cache.size() == 3
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES
def make_config(**overrides) -> LingBotVAConfig:
kwargs = {"device": "cpu"}
kwargs.update(overrides)
return LingBotVAConfig(**kwargs)
def test_registered_in_choice_registry() -> None:
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
def test_type_property() -> None:
assert make_config().type == "lingbot_va"
def test_chunk_size_and_action_steps() -> None:
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
assert cfg.chunk_size == 16
assert cfg.n_action_steps == 16
assert cfg.action_delta_indices == list(range(16))
assert cfg.observation_delta_indices is None
assert cfg.reward_delta_indices is None
def test_optimizer_and_scheduler_presets() -> None:
cfg = make_config()
opt = cfg.get_optimizer_preset()
assert opt.lr == cfg.optimizer_lr
sched = cfg.get_scheduler_preset()
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
def test_validate_features_sets_action_feature() -> None:
cfg = make_config()
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
assert ACTION in cfg.output_features
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
def test_validate_features_no_visual_raises() -> None:
cfg = make_config()
cfg.input_features = {}
cfg.output_features = {}
with pytest.raises(ValueError, match="at least one visual input feature"):
cfg.validate_features()
def test_invalid_attn_mode_raises() -> None:
with pytest.raises(ValueError, match="attn_mode"):
make_config(attn_mode="banana")
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.policies.factory import make_policy_config
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
def test_make_policy_config_returns_lingbot_va() -> None:
cfg = make_policy_config("lingbot_va", device="cpu")
assert isinstance(cfg, LingBotVAConfig)
def test_get_policy_class_resolves_lazily() -> None:
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
pytest.importorskip("diffusers")
pytest.importorskip("transformers")
from lerobot.policies.factory import get_policy_class
cls = get_policy_class("lingbot_va")
assert cls.name == "lingbot_va"
assert cls.config_class is LingBotVAConfig
+131
View File
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
FlowMatchScheduler,
data_seq_to_patch,
get_mesh_id,
)
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
assert sch.timesteps.shape == (20,)
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
assert torch.all(diffs <= 0) # decreasing
def test_flow_match_scheduler_step_preserves_shape() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.zeros(1, 48, 4, 8, 16)
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
assert out.shape == sample.shape
def test_flow_match_scheduler_add_noise() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.randn(1, 48, 4, 8, 16)
noise = torch.randn_like(sample)
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
assert noisy.shape == sample.shape
def test_get_mesh_id_latent_shape() -> None:
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
def test_get_mesh_id_action_shape() -> None:
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
assert grid.shape == (4, 4 * 4 * 1)
# Action rows for h/w are sentinel -1.
assert torch.all(grid[1] < 0)
assert torch.all(grid[2] < 0)
def test_data_seq_to_patch_roundtrip_shape() -> None:
b, f, h, w, c = 1, 4, 8, 16, 48
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
assert out.shape == (b, c, f, h, w)
def test_training_step_reduces_loss_tiny_flex() -> None:
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
"""
if not torch.cuda.is_available():
import pytest
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES
cfg = LingBotVAConfig(
attn_mode="flex",
dtype="bfloat16",
in_channels=16,
out_channels=16,
action_dim=8,
text_dim=32,
freq_dim=64,
ffn_dim=64,
num_attention_heads=2,
attention_head_dim=24,
num_layers=2,
frame_chunk_size=2,
action_per_frame=4,
used_action_channel_ids=[0, 1, 2, 3],
obs_cam_keys=[f"{OBS_IMAGES}.image"],
device="cuda",
)
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
cfg.validate_features()
policy = LingBotVAPolicy(cfg).to("cuda")
policy.train()
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
amask = torch.zeros(cfg.action_dim, device="cuda")
amask[cfg.used_action_channel_ids] = 1.0
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
loss.backward()
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
opt.step()
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline, UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def _make_config() -> LingBotVAConfig:
cfg = LingBotVAConfig(device="cpu")
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
return cfg
def test_make_pre_post_processors_names_and_steps() -> None:
cfg = _make_config()
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
# Actions are unnormalized by the standard built-in quantile unnormalizer.
assert any(isinstance(s, UnnormalizerProcessorStep) for s in post.steps)
def test_freshly_built_postprocessor_is_identity() -> None:
# Without action stats the quantile unnormalizer is a no-op (identity passthrough): the real
# per-benchmark q01/q99 are restored from the saved checkpoint on load, not hardcoded here.
cfg = _make_config()
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
normed = torch.tensor([[0.3, -0.5, 1.0, -1.0, 0.0, 0.7, -0.2]])
assert torch.allclose(post(normed), normed, atol=1e-6)
def test_postprocessor_quantile_unnormalization() -> None:
# QUANTILES unnormalize maps [-1, 1] -> [q01, q99]: -1 -> q01, +1 -> q99.
cfg = _make_config()
q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
stats = {ACTION: {"q01": q01, "q99": q99}}
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=stats)
out_lo = post(torch.full((1, 7), -1.0))
out_hi = post(torch.full((1, 7), 1.0))
assert torch.allclose(out_lo, torch.tensor(q01).unsqueeze(0), atol=1e-4)
assert torch.allclose(out_hi, torch.tensor(q99).unsqueeze(0), atol=1e-4)
def test_postprocessor_stats_survive_save_load(tmp_path) -> None:
# Regression guard for the Hub mechanism: the q01/q99 stats live in the saved post-processor
# state and must round-trip through save_pretrained / from_pretrained.
cfg = _make_config()
q01 = [-0.6, -0.8, -0.9, -0.1, -0.15, -0.25, -1.0]
q99 = [0.9, 0.85, 0.9, 0.17, 0.18, 0.34, 1.0]
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats={ACTION: {"q01": q01, "q99": q99}})
post.save_pretrained(tmp_path)
loaded = PolicyProcessorPipeline.from_pretrained(
tmp_path,
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
out = loaded(torch.full((1, 7), -1.0))
assert torch.allclose(out, torch.tensor(q01).unsqueeze(0), atol=1e-4)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1 @@
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
@@ -0,0 +1,22 @@
from dataclasses import dataclass
@dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
def get_config(variant: str) -> Config:
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
if variant == "dummy":
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
if variant == "gemma_300m":
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
if variant == "gemma_2b":
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
raise ValueError(f"Unknown variant: {variant}")
@@ -0,0 +1,300 @@
from typing import Literal
import torch
from torch import nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
class PaliGemmaWithExpertModel(nn.Module):
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Force enable gradient checkpointing if we're in training mode and the model supports it
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
if not self.gemma_expert.model.gradient_checkpointing:
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
self.gemma_expert.model.gradient_checkpointing = True
use_gradient_checkpointing = True
# Debug gradient checkpointing status
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
print(f"Model training mode: {self.training}")
print(
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
)
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
print(
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
)
self._debug_gc_printed = True
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
models = [self.paligemma.model.language_model, self.gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(
layer.input_layernorm, hidden_states, adarms_cond[i]
)
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
self.paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
)
# Old code removed - now using compute_layer_complete function above
# final norm
# Define final norm computation function for gradient checkpointing
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values
@@ -0,0 +1,79 @@
import torch
import torch.nn.functional as F # noqa: N812
def resize_with_pad_torch(
images: torch.Tensor,
height: int,
width: int,
mode: str = "bilinear",
) -> torch.Tensor:
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
by padding with black. If the image is float32, it must be in the range [-1, 1].
Args:
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
height: Target height
width: Target width
mode: Interpolation mode ('bilinear', 'nearest', etc.)
Returns:
Resized and padded tensor with same shape format as input
"""
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
if images.shape[-1] <= 4: # Assume channels-last format
channels_last = True
# Convert to channels-first for torch operations
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
else:
channels_last = False
if images.dim() == 3:
images = images.unsqueeze(0) # Add batch dimension
batch_size, channels, cur_height, cur_width = images.shape
# Calculate resize ratio
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
# Resize
resized_images = F.interpolate(
images,
size=(resized_height, resized_width),
mode=mode,
align_corners=False if mode == "bilinear" else None,
)
# Handle dtype-specific clipping
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
# Calculate padding
pad_h0, remainder_h = divmod(height - resized_height, 2)
pad_h1 = pad_h0 + remainder_h
pad_w0, remainder_w = divmod(width - resized_width, 2)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
mode="constant",
value=constant_value,
)
# Convert back to original format if needed
if channels_last:
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
if batch_size == 1 and images.shape[0] == 1:
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
return padded_images
@@ -0,0 +1,471 @@
import copy
import logging
import math
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device):
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
class PI0Pytorch(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pi05 = config.pi05
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True] if self.pi05 else [False, False],
precision=config.dtype,
)
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
if self.pi05:
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
else:
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
torch.set_float32_matmul_precision("high")
if config.pytorch_compile_mode is not None:
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
# The upstream OpenPI module verifies a site-package Transformers patch here.
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def is_gradient_checkpointing_enabled(self):
"""Check if gradient checkpointing is enabled."""
return self.gradient_checkpointing_enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
return torch.utils.checkpoint.checkpoint(
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
def _preprocess_observation(self, observation, *, train=True):
"""Helper method to preprocess observation."""
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
return (
list(observation.images.values()),
list(observation.image_masks.values()),
observation.tokenized_prompt,
observation.tokenized_prompt_mask,
observation.state,
)
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
embs = []
pad_masks = []
att_masks = []
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
# Process language tokens
def lang_embed_func(lang_tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# Get batch size from the first dimension of the concatenated tensors
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
if not self.pi05:
if self.state_proj.weight.dtype == torch.float32:
state = state.to(torch.float32)
# Embed state
def state_proj_func(state):
return self.state_proj(state)
state_emb = self._apply_checkpoint(state_proj_func, state)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep,
self.action_in_proj.out_features,
min_period=4e-3,
max_period=4.0,
device=timestep.device,
)
time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
def action_proj_func(noisy_actions):
return self.action_in_proj(noisy_actions)
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
if not self.pi05:
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
# Apply MLP layers
def mlp_func(action_time_emb):
x = self.action_time_mlp_in(action_time_emb)
x = F.silu(x) # swish == silu
return self.action_time_mlp_out(x)
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
adarms_cond = None
else:
# time MLP (for adaRMS)
def time_mlp_func(time_emb):
x = self.time_mlp_in(time_emb)
x = F.silu(x) # swish == silu
x = self.time_mlp_out(x)
return F.silu(x)
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
action_time_emb = action_emb
adarms_cond = time_emb
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks, adarms_cond
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=True
)
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
# Prepare attention masks
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Apply gradient checkpointing if enabled
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
# Apply gradient checkpointing to final action projection if enabled
def action_out_proj_func(suffix_out):
return self.action_out_proj(suffix_out)
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad()
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = observation.state.shape[0]
if noise is None:
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
noise = self.sample_noise(actions_shape, device)
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
observation, train=False
)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=True,
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step - use new tensor assignment instead of in-place operation
x_t = x_t + dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# Prepare attention masks
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.action_horizon :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
@@ -0,0 +1,179 @@
import logging
from collections.abc import Sequence
import torch
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
logger = logging.getLogger("openpi")
# Constants moved from model.py
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation_pytorch(
observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
This function avoids complex type annotations that can cause torch.compile issues.
"""
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
# Handle both [B, C, H, W] and [B, H, W, C] formats
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
if is_channels_first:
# Convert [B, C, H, W] to [B, H, W, C] for processing
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad_torch(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
image = image / 2.0 + 0.5
# Apply PyTorch-based augmentations
if "wrist" not in key:
# Geometric augmentations for non-wrist cameras
height, width = image.shape[1:3]
# Random crop and resize
crop_height = int(height * 0.95)
crop_width = int(width * 0.95)
# Random crop
max_h = height - crop_height
max_w = width - crop_width
if max_h > 0 and max_w > 0:
# Use tensor operations instead of .item() for torch.compile compatibility
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
# Resize back to original size
image = torch.nn.functional.interpolate(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
size=(height, width),
mode="bilinear",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Random rotation (small angles)
# Use tensor operations instead of .item() for torch.compile compatibility
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
# Convert to radians
angle_rad = angle * torch.pi / 180.0
# Create rotation matrix
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
# Apply rotation using grid_sample
grid_x = torch.linspace(-1, 1, width, device=image.device)
grid_y = torch.linspace(-1, 1, height, device=image.device)
# Create meshgrid
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
# Expand to batch dimension
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
# Apply rotation transformation
grid_x_rot = grid_x * cos_a - grid_y * sin_a
grid_y_rot = grid_x * sin_a + grid_y * cos_a
# Stack and reshape for grid_sample
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
image = torch.nn.functional.grid_sample(
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
grid,
mode="bilinear",
padding_mode="zeros",
align_corners=False,
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
# Color augmentations for all cameras
# Random brightness
# Use tensor operations instead of .item() for torch.compile compatibility
brightness_factor = (
0.7 + torch.rand(1, device=image.device) * 0.6
) # Random factor between 0.7 and 1.3
image = image * brightness_factor
# Random contrast
# Use tensor operations instead of .item() for torch.compile compatibility
contrast_factor = (
0.6 + torch.rand(1, device=image.device) * 0.8
) # Random factor between 0.6 and 1.4
mean = image.mean(dim=[1, 2, 3], keepdim=True)
image = (image - mean) * contrast_factor + mean
# Random saturation (convert to HSV, modify S, convert back)
# For simplicity, we'll just apply a random scaling to the color channels
# Use tensor operations instead of .item() for torch.compile compatibility
saturation_factor = (
0.5 + torch.rand(1, device=image.device) * 1.0
) # Random factor between 0.5 and 1.5
gray = image.mean(dim=-1, keepdim=True)
image = gray + (image - gray) * saturation_factor
# Clamp values to [0, 1]
image = torch.clamp(image, 0, 1)
# Back to [-1, 1]
image = image * 2.0 - 1.0
# Convert back to [B, C, H, W] format if it was originally channels-first
if is_channels_first:
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
out_images[key] = image
# obtain mask
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
else:
out_masks[key] = observation.image_masks[key]
# Create a simple object with the required attributes instead of using the complex Observation class
class SimpleProcessedObservation:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
return SimpleProcessedObservation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
token_ar_mask=observation.token_ar_mask,
token_loss_mask=observation.token_loss_mask,
)
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.policies.pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
assert_cache_stability,
assert_compiled_output_matches_eager,
assert_explain_has_no_graph_breaks,
benchmark_runtime,
make_compile_config,
reset_compile_state,
)
from tests.utils import require_cuda # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
def _make_model(*, compile_model):
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
def _make_dummy_inputs(config):
device = torch.device("cuda")
common = {
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
}
forward_kwargs = {
**common,
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"time": torch.rand(1, device=device),
}
sample_kwargs = {
**common,
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
"num_steps": config.num_inference_steps,
}
return forward_kwargs, sample_kwargs
@require_cuda
def test_pi05_torch_compile_forward_and_sample_actions():
if not hasattr(torch, "compile"):
pytest.skip("torch.compile is not available")
if not torch._dynamo.is_dynamo_supported():
pytest.skip("torch._dynamo is not supported on this platform")
torch.manual_seed(0)
eager_model = _make_model(compile_model=False)
torch.manual_seed(0)
compiled_model = _make_model(compile_model=True)
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
try:
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
benchmark_runtime(
eager_model.sample_actions,
compiled_model.sample_actions,
sample_kwargs,
"pi05.sample_actions",
)
finally:
reset_compile_state()
del eager_model
del compiled_model
torch.cuda.empty_cache()
@@ -14,52 +14,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
from lerobot.configs import PreTrainedConfig # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
assert_processor_inputs_match_lerobot,
clone_batch,
deterministic_openpi_forward_preprocess,
fix_reference_state_dict,
fixed_flow_sampling,
load_openpi_reference_state_dict,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
)
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
from lerobot.types import PolicyAction # noqa: E402
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COMPILE_MODE = "default"
FORWARD_RTOL = 1e-4
FORWARD_ATOL = 1e-4
SAMPLE_RTOL = 1e-2
SAMPLE_ATOL = 5e-3
DUMMY_DATASET_STATS = {
"observation.state": {
OBS_STATE: {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
ACTION: {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
}
@pytest.fixture(autouse=True)
def cleanup_cuda_after_test():
yield
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class PI05BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
precision: str = "float32"
pi05: bool = True
dtype: str = "float32"
pytorch_compile_mode: str | None = None
def instantiate_lerobot_pi05(
from_pretrained: bool = False,
) -> tuple[
PI05Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
else:
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI05Policy(config)
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
config.device = str(DEVICE)
config.dtype = "float32"
config.compile_model = compile_model
config.compile_mode = COMPILE_MODE
config.gradient_checkpointing = gradient_checkpointing
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi05_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
policy.config.device = str(DEVICE)
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
return policy, preprocessor
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
config = PI05BaseOriginalConfig()
policy = PI0Pytorch(config)
def instantiate_original_pi05():
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
assert missing_keys == []
assert unexpected_keys == []
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
batch_size = 2
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
return {
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
ACTION: torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
torch.manual_seed(0)
raw_batch = create_dummy_data()
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
openpi_observation = make_openpi_observation_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
max_token_len=DUMMY_MAX_TOKEN_LEN,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
openpi_actions = openpi_model_actions_from_raw(
raw_batch,
action_dim=DUMMY_ACTION_DIM,
dataset_stats=DUMMY_DATASET_STATS,
pi05=True,
)
assert_processor_inputs_match_lerobot(
lerobot_pi05,
lerobot_batch,
openpi_observation,
compare_state=False,
)
batch_size = raw_batch[OBS_STATE].shape[0]
noise = torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=DEVICE,
)
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
return lerobot_batch, openpi_observation, openpi_actions, noise, time
class PI05Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description (pi05 processor handles all text formatting)
tasks = batch.get("task", ["Pick up the object"] * batch_size)
if isinstance(tasks, str):
tasks = [tasks] * batch_size
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
from lerobot.policies.pi05.modeling_pi05 import pad_vector
state = pad_vector(state, DUMMY_STATE_DIM)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create pi05-formatted prompts that include state information
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
compile_model=compile_model,
gradient_checkpointing=gradient_checkpointing,
)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
if gradient_checkpointing:
lerobot_pi05.train()
else:
lerobot_pi05.eval()
original_pi05.eval()
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
with deterministic_openpi_forward_preprocess(original_pi05):
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
openpi_loss = openpi_losses.mean(dim=(1, 2))
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
original_pi05 = instantiate_original_pi05()
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
lerobot_pi05,
lerobot_preprocessor,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi05_original_vs_lerobot():
"""Test PI05 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi05(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi05.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
original_pi05.eval()
with torch.no_grad():
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
openpi_actions = original_pi05.sample_actions(
device=DEVICE,
observation=openpi_observation,
noise=noise,
num_steps=10,
)
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
def test_pi05_forward_matches_openpi():
assert_forward_matches()
def test_pi05_sample_actions_match_openpi():
assert_sample_actions_match_openpi()
def test_pi05_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(gradient_checkpointing=True)
def test_pi05_compile_forward_matches_openpi():
assert_forward_matches(compile_model=True)
def test_pi05_compile_sample_actions_match_openpi():
assert_sample_actions_match_openpi(compile_model=True)
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
assert_forward_matches(compile_model=True, gradient_checkpointing=True)

Some files were not shown because too many files have changed in this diff Show More