Compare commits

..

168 Commits

Author SHA1 Message Date
Bryson Jones 3dbf463f3b Merge branch 'main' into feature/add-multitask-dit 2026-03-06 08:38:19 -08:00
Shun.Sasaki 6139b133ca fix(async_inference): restore robot module imports in robot_client.py (#3081) 2026-03-06 17:14:14 +01:00
Bryson Jones bc7740a15d Merge branch 'main' into feature/add-multitask-dit 2026-03-06 07:56:58 -08:00
Steven Palma 85de893fa7 fix(ci): skip HF log in (and tests) in forks and community PRs (#3097)
* fix(ci): skip HF log in (and tests) in forks and community PRs

* chore(test): remove comment about test meant to be only run locally

* fix(tests): no hf log in decorator for xvla

* fix(test): no decorator in yield
2026-03-06 16:33:43 +01:00
Bryson Jones 03e009db78 Merge branch 'main' into feature/add-multitask-dit 2026-03-06 07:22:23 -08:00
Steven Palma a4c66e530b chore(docs): remove pi installation note (#3095) 2026-03-06 15:52:54 +01:00
Steven Palma 0794b4ba8f Merge branch 'main' into feature/add-multitask-dit 2026-03-06 14:25:15 +01:00
Steven Palma a225127527 chore(dependencies): sync intelrealsense + added notes (#3094) 2026-03-06 10:50:46 +01:00
Steven Palma e489ba24fc feat(dependencies): require Python 3.12+ as minimum version (#3023)
* feat(dependecies): upgrade to python3.12

* fix(test): processor regex message

* fix(test): processor regex message

* fix(dependecies): resolve all tags in python 3.12

* fix(dependecies): add more hints to faster resolve

* chore(dependecies): remove cli tag huggingface-hub dep

* refactor(policy): update eagle for python3.12

* chore(docs): update policy creation for python 3.12

* chore(test): skip failing tests in macos
2026-03-06 10:15:13 +01:00
Bryson Jones eb85d2d541 Merge branch 'main' into feature/add-multitask-dit 2026-03-05 20:21:36 -08:00
Steven Palma d324ffe810 fix(ci): test only multi-gpu tests in multi-gpu runner (#3092) 2026-03-05 19:53:40 +01:00
Bryson Jones d7537c85c5 Merge branch 'main' into feature/add-multitask-dit 2026-03-05 10:06:54 -08:00
Pepijn 1a24f770d3 Feat/slurm compute rabc script (#3041)
* Add SLURM SARM progress annotation script.

Provide a standalone two-stage compute/aggregate pipeline for RA-BC progress generation so large datasets can be processed in parallel and optionally uploaded to the Hub.

Made-with: Cursor

* fix pr comments

* remove comments
2026-03-05 18:27:58 +01:00
Bryson Jones b2a32b8076 Merge remote-tracking branch 'upstream/main' into feature/add-multitask-dit
# Conflicts:
#	pyproject.toml
#	src/lerobot/policies/pi_gemma.py
#	tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py
#	tests/policies/pi0_pi05/test_pi0.py
#	tests/policies/pi0_pi05/test_pi05.py
2026-03-05 07:32:09 -08:00
Caroline Pascal 92fba37225 fix(num_frames): fixing redundant frames count in conversion script (#3091) 2026-03-05 15:49:50 +01:00
Steven Palma 3e45120272 fix(ci): log in HF for gated repo in nightly workflows (#3089)
* fix(ci): log in HF for gated repo in nightly workflows

* fix(ci): add env var

* fix(ci): remove 10 min limit for multi-gpu nightly
2026-03-05 13:22:37 +01:00
Steven Palma f0d2b37beb chore(dependencies): bump transformers v5 (#2964)
* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy

* chore(dependencies): bump pi0 family to transformers v5

* chore(dependencies): bump wall x to transformers v5

* chore(dependencies): bump gr00t to transformers v5

* chore(style): fix pre-commit

* fix(policy): xvla forced_bos_token missing

* test(rl): skip ci tests for resnet10

* Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)

* test(policies): enable wall x CI testing

* style(test): pre-commit check

* style(test): pre-commit

* fix wall x for transformer v5 (#3008)

* tv5 fix

* various wall x fixes

* Delete tests/policies/pi0_pi05/print_pi05_output_logits.py

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* sync modeling_florence2.py with chore/bump_transformers_v5

* more

* more fixes

* more

* remove comment

* more

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* chore(dependencies): adjust dependencies versioning after transformers v5 (#3034)

* chore(dependecies): adjust dependecies versioning after transformers v5

* fix(policies): remove deprecated input_embeds

* fix(policies): dict _tied_weights_keys

* chore(depedencies): common qwen-vl-utils

* chore(dependencies): bump transformers to 5.2

* Fix policy testing for tv5 (#3032)

* fix ci logger

* other fix

* fix mypy

* change logits to torch2.10

* skip wallx|

* remove logging

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>

* feat(ci): log into HF to unblock some CI tests (#3007)

* feat(ci): log into HF to unblock some CI tests

* chore(ci): change hf call + secret name

* fix(ci): temp fix for pi0 rtc test

* test(policies): require_cuda for unblocked tests

* test(policies): require_cuda wall_x

* fic(tests): require_cuda outter most for pi0

* fix(test): return instead of yield

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* style(test): fix pre-commit

* chore(deps): upgrade transformers (#3050)

* chore(test): use lerobot model

* fix(policies): change default action tokenizer for wall x

* sample on cpu

* Revert "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit d9b76755f7, reversing
changes made to 89359cb0b6.

* Reapply "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit c9914db78b.

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-03-05 09:25:26 +01:00
Caroline Pascal cbc8bfb2e6 chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label (#3077)
* chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label

* chore(task): renamming the default index label in the tasks DataFrame to task

* Revert "chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label"

This reverts commit f55de3255278f23f18b5d955565f6768d094951d.

* chore(docstrings): updating docstrings to match dataset v3.0 architecture

* chore(format): formatting code
2026-03-04 17:59:03 +01:00
Paul Crook 0d1be72dc8 Fixing metadata indexing when writing new Parquet file (#2941)
* Fixing metadata indexing when writing new Parquet file

Summary:
  - addressing this issue: https://github.com/huggingface/lerobot/issues/2401
  - vibe-coded bugfix by Claude Sonnet 4.5

* Backing out changes to convert_videos_of_camera

* Addressing Ruff pre-commit complaint

Summary:
 - addressing "SIM113 Use `enumerate()` for index variable `ep_idx` in `for` loop"

---------

Co-authored-by: Paul <238953601+pac-robotics@users.noreply.github.com>
2026-03-04 16:53:34 +01:00
Maxime Ellerbach 96b7c212c4 chore(docs): updating deprecated huggingface-cli to hf (#3071)
* chore(docs): updating deprecated huggingface-cli to hf

* small typo in my-org
2026-03-04 15:08:49 +01:00
Caroline Pascal 4303b3c930 chore(root): fixing root semantics in convert_dataset script (#3073)
* fix(root): fixing root semantincs in convert_dataset script

* fix(\): fixing command syntax in dataset conversion script

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-03-04 11:11:21 +01:00
Caroline Pascal 63dca86df8 fix(dataset edit tools): clarifying root argument usage + adding related features (#3049)
* fix(root): adding proper support for the root and new_root arguments

* feat(roots): adding a roots agrument for the merge operation

* chore(clean): cleaning up code

* chore(doctrings): updating doctrings with new features

* fix(repo_id): setting repo_id to None when not needed

* fix(roots/repo_ids): making mypy happy by using repo_ids and roots for merge operation

* fix(path): fixing path related issues

* fix(repo_id): fixing issues related to repo_id

* chore(doctrings): updating docstrings + fix typo

* chore(clean): cleaning code

* fix(split new_repo_id): reverting new_repo_id addition for split operation

* docs(dosctrings): completing docstrings

* fix(repo_ids/roots): improving checks for repo_ids/roots lengths

* fix(repo_ids): making repo_ids optional in MergeConfig but raise if not given

* fix(docstrings): fixing docstrings for split operation

* fix(hints): updating get_output_path hints to accept paths as strings too

* fix(y/N prompts): removing y/N prompts in lerobot_edit_dataset

* fix(merge repo_id): fixing merge operation to use new_repo_id instead of repo_id

* fix(typo): fixing typo in doctrings
2026-03-03 15:40:46 +01:00
Caroline Pascal 8a0cc3d664 fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets (#3068)
* fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets

* fix(segfault risk): removing segfault risk by calling  batch["index"] in the dataloader loop
2026-03-03 11:55:09 +01:00
Bernie Telles 8bb8ed4803 Improve policy_device documentation for async.mdx (#3060) 2026-03-02 15:35:15 +01:00
Steven Palma 095856b06a chore: add AI policy (#3055) 2026-02-28 14:41:28 +01:00
Steven Palma 563f42bdb1 chore(dependencies): Bump lerobot to 0.4.5 (#3051) 2026-02-27 19:29:35 +01:00
Caroline Pascal 8fff0fde7c chore(docstrings): fixing deprecated root argument description in LeRobotDataset class (#3035)
* chore(docstrings): fixing deprecated `root` argument docstrings in LeRobotDataset class

* chore(draccus): updating draccus CLI help

* chore(revert): reverting changes in lerobot_dataset_viz.py

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:22:44 +01:00
Pepijn 04de496547 fix(logging): avoid double-counting samples across processes (#3045) 2026-02-27 17:45:19 +01:00
Khalil Meftah baf9b50365 Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)
* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
2026-02-27 17:44:53 +01:00
Jade Choghari a0fdbf037a feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

* Add torch.compile for smolvla

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Add torch.compile for smolvla

Add model compilation option for improved performance.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* first

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:58:36 +03:00
Khalil Meftah c085531b17 fix: add missing openarm_mini import to CLI scripts (#3028) 2026-02-27 15:46:31 +01:00
Steven Palma c7c6205332 chore(scripts): no spam log when no action (#3042) 2026-02-27 15:26:56 +01:00
Michio Sun 4e54be1334 fix(datasets): skip warning when MultiLeRobotDataset features are identical (#3019)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 17:42:22 +01:00
Damien LaRocque fde9d08281 feat(async_inference) Enable plugins with async inference (#2425)
* feat(async-inference) Try using async inference server with plugins

* Fix import

* Fix import error in Robot Client

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 14:41:32 +01:00
Khalil Meftah 46044fed75 Fix: remove device_map from SmolVLA model loading (#3029)
* Fix SmolVLA meta tensor error by removing device_map

- Remove device_map parameter from VLM model loading
- Change torch_dtype from string to torch.bfloat16
- Add explicit .to(device) calls after initialization

This resolves NotImplementedError when training SmolVLA policy.
Fixes meta tensor copy issue in factory.py:418.

* fix: remove manual device movement logic and fix dtype handling

---------

Co-authored-by: Highsky7 <albert31115@gmail.com>
2026-02-26 13:28:46 +01:00
Khalil Meftah 975dcad918 Feat(teleoperators): add OpenArm Mini teleoperator (#3022)
* add OpenArm Mini config and module init

* add OpenArm Mini teleoperator implementation

* add OpenArm Mini into factory and setup motors

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 18:46:55 +01:00
Cotton Hu d0b58190da fix(policies): support dp train when n_obs_steps=1 (#2430)
Co-authored-by: hukongtao <hukongtao@agibot.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-25 17:36:31 +01:00
Mishig 9a5ab8ffab feat: add visualization badge to card template and update dataset card creation with repo_id (#3005)
* feat: add visualization badge to card template and update dataset card creation with repo_id

* Update src/lerobot/datasets/card_template.md

* Update src/lerobot/datasets/card_template.md

---------

Signed-off-by: Mishig <dmishig@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-25 16:02:40 +01:00
Khalil Meftah 7541d72130 Fix SARM dense_only mode: always load episodes_df for target computation (#3021)
* fix annotation mode check

* fix: SARM dense_only mode always load episodes_df for target computation

---------

Co-authored-by: John Newsom <jackmnewsom@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 13:28:01 +01:00
Jash Shah 0317a15bf1 fix(video): replace assertions with proper exceptions in video frame decoding (#3016)
Replaced assert statements with FrameTimestampError exceptions in
decode_video_frames_torchvision and decode_video_frames_torchcodec.

Assertions are unsuitable for runtime validation because they can be
silently disabled with python -O, and they produce unhelpful
AssertionError tracebacks. The codebase already defines
FrameTimestampError for this exact purpose but it was only used
in one of the three validation sites.

Also removed AssertionError from the except clause in
LeRobotDataset.__init__, which was masking video timestamp errors
by silently triggering a dataset re-download instead of surfacing
the actual problem.
2026-02-25 12:29:22 +01:00
Jash Shah f138e5948a Fix metaworld_config.json not bundled in pip installs and AttributeError crash (#3017)
1. Include metaworld_config.json in package distributions by adding it to
   both MANIFEST.in (for sdist) and pyproject.toml package-data (for wheels).
   Without this, pip-installed lerobot raises FileNotFoundError when
   importing the metaworld environment.

2. Fix crash in sanity_check_dataset_name where the error message accesses
   policy_cfg.type when policy_cfg is None, raising AttributeError instead
   of the intended ValueError.

Fixes #2958
2026-02-25 12:29:10 +01:00
Martin Kiefel 8fef4ddab8 fix(dataset): Fix reindexing bug for videos on splits (#2548)
* fix(dataset): Reindex videos based on frame and not on time

Sometimes during split operations the frame timestamp floating
precision leads to frame ending up in the wrong split.

This changes fixes the issues by directly working with frame indices
instead.

* Fix formatting
2026-02-25 11:57:07 +01:00
Steven Palma 18d9cb5ac4 feat(scripts): Integrate tqdm for training progress visualization (#3010) 2026-02-24 19:10:43 +01:00
Steven Palma 5095ab0845 fix(ci): permissions triton (#3011) 2026-02-24 19:09:34 +01:00
Bryson Jones cb296ee58f Merge remote-tracking branch 'upstream/chore/bump_transformers_v5' into feature/add-multitask-dit
# Conflicts:
#	pyproject.toml
2026-02-24 09:10:12 -08:00
Bryson Jones 428ea89ff2 Merge branch 'main' into feature/add-multitask-dit 2026-02-24 08:41:50 -08:00
Jash Shah dac1efd13d feat: Enable torch.compile for DiffusionPolicy inference (#2486)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-24 17:29:08 +01:00
Steven Palma b9cb947bd2 style(test): pre-commit 2026-02-24 15:35:25 +01:00
Steven Palma 8440c561ae Merge branch 'main' into feature/add-multitask-dit 2026-02-24 15:34:18 +01:00
Steven Palma 11cefed08a style(test): pre-commit check 2026-02-24 10:57:29 +01:00
Steven Palma 7bfedd1388 test(policies): enable wall x CI testing 2026-02-24 10:53:23 +01:00
Jade Choghari 8c95a71c94 chore: fix XVLA in transformers v5 (#3006) 2026-02-24 10:29:48 +01:00
Steven Palma 1d048c7e2b Merge branch 'main' into chore/bump_transformers_v5 2026-02-23 20:54:02 +01:00
Jade Choghari 419305a4c2 Fix: full pi models support for transformer v5 (#2967)
* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>
2026-02-23 22:44:13 +03:00
Dominik Paľo 7fd71c83a3 docs: add WSL evdev installation note (#2855)
Add a note in the installation guide explaining that users on WSL need to install evdev to avoid build issues.
See: https://github.com/huggingface/lerobot/issues/2528

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 20:41:20 +01:00
Yuan Haokuan 0f44adbeec docs: fix HF_USER export command to correctly parse username (#2932)
* Fix HF_USER extraction command in documentation

Updated command to extract the username from hf auth output.

Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com>

* Correct HF_USER variable assignment in documentation

Fix the variable extraction from hf auth output.

Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com>

* Update docs/source/il_robots.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com>

---------

Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 17:51:13 +01:00
Guilherme Miotto 7dbbaa3727 Small comment fix (#2990)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 17:11:55 +01:00
Yuta Nakagawa fcabfd32a5 chore(docs): update the document for Phone teleop to clarify how to use the examples (#2991)
* update the document for Phone teleope to clarify how to use the examples

* Update docs/source/phone_teleop.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Yuta Nakagawa <ytnkgw@gmail.com>

---------

Signed-off-by: Yuta Nakagawa <ytnkgw@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 17:11:46 +01:00
Steven Palma 544cbc5f38 feat(motors): add RobStride CAN implementation (#2821)
* feat(motors): add initial implementation of robstride

Co-authored-by: Virgile <virgilebatto@gmail.com>

* chore(motors): solve some linter

* remove kp/kd attribute

* code uniformisation between damiao and robstride

* remove normalization warning

* remove non valid baudrates and small docstring update

* remove all useless files. Only keeping robstride.py and table.py

* typing for mypy

* reduce NameOrId usage

* align signature with damiao

* put the same helper than in the damiao implementation

* bug correction : expect a response after each bus.send

---------

Co-authored-by: Virgile <virgilebatto@gmail.com>
2026-02-23 16:39:04 +01:00
Yueci Deng a0c5d19391 add metadata_buffer_size to dataset creation (#2998)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-23 16:32:59 +01:00
Steven Palma e96339a3b4 feat(dataset): add streaming video encoding + HW encoder support (#2974)
* feat(dataset): init stream encoding

* feat(dataset): use threads to fix frame pickle latency

* refactor(dataset): remove HW encoded related changes

* add lp (#2977)

* feat(dataset): add Hw encoding + log drop frames (#2978)

* chore(docs): add streaming video encoding guide

* fix(dataset): style docs + testing

* chore(docs): simplify sttreaming video encoding guide

* chore(dataset): add commands + streaming encoding default false + print note if false + queue default is now 30

* chore(docs): add verification note advice

* chore(dataset): adjusting defaults & docs for streaming encoding

* docs(scripts): improve docstrings

* test(dataset): polish streaming encoding tests

* chore(dataset): move FYI log related to streaming

* chore(dataset): add arg vcodec to suggestions

* refactor(dataset): better handling for auto and available vcodec

* chore(dataset): change log level

* docs(dataset): add note related to training performance vcodec

* docs(dataset): add more notes to streaming encoding

---------

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-02-23 13:57:43 +01:00
Steven Palma 5865170d36 chore(deps): bump ceil datasets (#2946) 2026-02-20 17:01:46 +01:00
Bryson Jones 23d04ca6fd Merge branch 'main' into feature/add-multitask-dit 2026-02-16 15:38:54 -08:00
Steven Palma 753b996cda test(rl): skip ci tests for resnet10 2026-02-12 21:25:39 +01:00
Steven Palma 099f3ba4d7 fix(policy): xvla forced_bos_token missing 2026-02-12 21:14:53 +01:00
Steven Palma 3f3d08e5a8 chore(style): fix pre-commit 2026-02-12 19:37:57 +01:00
Steven Palma 9e1a67c862 chore(dependencies): bump gr00t to transformers v5 2026-02-12 19:33:56 +01:00
Steven Palma 54c38627bd chore(dependencies): bump wall x to transformers v5 2026-02-12 19:33:34 +01:00
Steven Palma f0ef3717ca chore(dependencies): bump pi0 family to transformers v5 2026-02-12 19:31:23 +01:00
Steven Palma bd8e1ccf70 chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy 2026-02-12 19:17:54 +01:00
Bryson Jones 6243b70239 Merge branch 'main' into feature/add-multitask-dit 2026-02-09 19:38:33 -08:00
Bryson Jones cc4377c346 Merge branch 'main' into feature/add-multitask-dit 2026-02-05 07:52:44 -08:00
Bryson Jones 50e13da845 update docs and readme files, fixing some typos and adding multitask dit to readme 2026-01-31 08:04:24 -08:00
Bryson Jones 7cca09d3da Merge branch 'main' into feature/add-multitask-dit 2026-01-30 23:04:42 -08:00
Bryson Jones ddfd573853 Merge branch 'main' into feature/add-multitask-dit 2026-01-29 10:23:39 -08:00
Bryson Jones 6d87b52abd Merge branch 'main' into feature/add-multitask-dit 2026-01-28 09:20:43 -08:00
Bryson Jones e10e352ca0 Merge branch 'main' into feature/add-multitask-dit 2026-01-27 09:45:39 -08:00
Bryson Jones d755df4461 Merge branch 'main' into feature/add-multitask-dit 2026-01-26 10:20:31 -08:00
Bryson Jones 0840dd6b94 Merge branch 'main' into feature/add-multitask-dit 2026-01-26 08:44:41 -08:00
Bryson Jones 9c4c988573 Merge branch 'main' into feature/add-multitask-dit 2026-01-25 21:44:45 -08:00
Bryson Jones f51e9fc1e4 revert fast tests edits 2026-01-22 08:53:30 -08:00
Bryson Jones 802d73edec Merge branch 'main' into feature/add-multitask-dit 2026-01-22 08:51:17 -08:00
Bryson Jones 3dd51c335e Merge branch 'main' into feature/add-multitask-dit 2026-01-20 21:05:31 -08:00
Bryson Jones b8e36bb2f3 Merge branch 'main' into feature/add-multitask-dit 2026-01-20 07:43:57 -08:00
Bryson Jones 2f3b1b3db8 Merge branch 'main' into feature/add-multitask-dit 2026-01-19 07:59:54 -08:00
Bryson Jones 8f33a1392e update tests script to not use unnecessary uv sync call which resolves dependencies that do not need to run. This drastically reduces CI run time 2026-01-17 09:20:36 -08:00
Bryson Jones f3dc152e94 Merge branch 'main' into feature/add-multitask-dit 2026-01-16 10:14:40 -08:00
Bryson Jones 673e18ab46 Merge branch 'main' into feature/add-multitask-dit 2026-01-16 07:51:55 -08:00
Bryson Jones f4683127b2 Merge branch 'main' into feature/add-multitask-dit 2026-01-15 22:41:41 -08:00
Bryson Jones bdac9d7df8 Merge branch 'main' into feature/add-multitask-dit 2026-01-15 10:13:08 -08:00
Bryson Jones c71b30a06c Merge branch 'main' into feature/add-multitask-dit 2026-01-15 08:01:08 -08:00
Bryson Jones 238ea68382 Merge branch 'main' into feature/add-multitask-dit 2026-01-14 08:30:26 -08:00
Bryson Jones 2abb20262e Merge branch 'main' into feature/add-multitask-dit 2026-01-13 10:02:41 -08:00
Bryson Jones ef3b299145 Merge branch 'main' into feature/add-multitask-dit 2026-01-12 17:14:08 -08:00
Bryson Jones c95d151913 Merge branch 'main' into feature/add-multitask-dit 2026-01-12 10:25:16 -08:00
Bryson Jones 699191f81f Merge branch 'main' into feature/add-multitask-dit 2026-01-12 09:13:46 -08:00
Bryson Jones 6d9271940b Merge branch 'main' into feature/add-multitask-dit 2026-01-12 08:22:31 -08:00
Bryson Jones c24cbaacf9 Merge branch 'main' into feature/add-multitask-dit 2026-01-10 17:04:08 -08:00
Bryson Jones 0e7bfa5624 Merge branch 'main' into feature/add-multitask-dit 2026-01-09 11:03:45 -08:00
Bryson Jones fdcbd1a936 Merge branch 'main' into feature/add-multitask-dit 2026-01-08 11:34:34 -08:00
Bryson Jones dd39503802 Merge branch 'main' into feature/add-multitask-dit 2026-01-08 08:17:15 -08:00
Pepijn db17c08f6e Merge branch 'main' into feature/add-multitask-dit 2026-01-07 18:21:19 +01:00
Bryson Jones b1ab9b9c46 Merge branch 'main' into feature/add-multitask-dit 2026-01-07 08:11:43 -08:00
Bryson Jones 5c6714bc1b Merge branch 'main' into feature/add-multitask-dit 2026-01-06 15:35:04 -08:00
Bryson Jones 7581adf4cd Merge branch 'main' into feature/add-multitask-dit 2026-01-06 11:57:09 -08:00
Pepijn e9384ec834 Merge branch 'main' into feature/add-multitask-dit 2026-01-06 19:20:44 +01:00
Pepijn df52c559d3 Merge branch 'main' into feature/add-multitask-dit 2026-01-06 16:00:12 +01:00
Bryson Jones a011accb7f add conflict management to pyproject toml for pi conflict for mtdp as well 2026-01-05 09:17:02 -08:00
Bryson Jones b98ccdde3a use hyphens for cleanliness in pyproject.toml 2026-01-05 08:59:41 -08:00
Bryson Jones 473528cb14 add wallx dep conflict management for multitask dit policy 2026-01-05 08:53:21 -08:00
Bryson Jones 713efb6427 Merge branch 'main' into feature/add-multitask-dit 2026-01-05 08:39:09 -08:00
Pepijn e268ec1ec5 Merge branch 'main' into feature/add-multitask-dit 2026-01-05 12:10:19 +01:00
Bryson Jones d75f3f8915 Merge branch 'main' into feature/add-multitask-dit 2026-01-04 22:06:54 -08:00
Bryson Jones 2cf13d9d63 Merge branch 'main' into feature/add-multitask-dit 2026-01-01 20:17:14 -08:00
Bryson Jones 8755bd0637 Merge branch 'main' into feature/add-multitask-dit 2025-12-28 04:59:09 -08:00
Bryson Jones 634e3924b8 Merge branch 'main' into feature/add-multitask-dit 2025-12-25 04:15:37 -08:00
Bryson Jones f5f9833540 add kwargs arg to multitask dit constructor 2025-12-25 04:14:56 -08:00
Bryson Jones e2b47a142a Merge branch 'main' into feature/add-multitask-dit 2025-12-23 15:41:18 -08:00
Bryson Jones 2a3444a8dd Merge branch 'main' into feature/add-multitask-dit 2025-12-23 15:06:41 -08:00
Bryson Jones 3e5f31e0be Merge branch 'main' into feature/add-multitask-dit
Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
2025-12-23 07:57:56 -08:00
Bryson Jones d653f96420 Merge branch 'main' into feature/add-multitask-dit
Signed-off-by: Bryson Jones <63133702+brysonjones@users.noreply.github.com>
2025-12-22 08:40:08 -08:00
Bryson Jones 2b90763597 Merge branch 'main' into feature/add-multitask-dit 2025-12-20 09:53:01 -08:00
Bryson Jones 632c778a2b Merge branch 'main' into feature/add-multitask-dit 2025-12-18 15:01:40 -08:00
Bryson Jones 77dbc951cd Merge branch 'main' into feature/add-multitask-dit 2025-12-18 08:36:25 -08:00
Bryson Jones 5b9f98169c Merge branch 'main' into feature/add-multitask-dit 2025-12-17 20:22:11 -08:00
Bryson Jones 2128dece82 Merge branch 'main' into feature/add-multitask-dit 2025-12-17 09:25:57 -08:00
Bryson Jones b575632f4f Merge branch 'main' into feature/add-multitask-dit 2025-12-17 07:37:09 -08:00
Bryson Jones 8a2f5aa6cb remove cropping of images smaller than the crop size 2025-12-15 22:20:20 -08:00
Bryson Jones 25ecd16b67 skip tests without transformers 2025-12-15 21:52:42 -08:00
Bryson Jones 1e049fbef7 add test handling for multitask dit when transformers isnt available 2025-12-15 21:21:35 -08:00
Bryson Jones afe2c4d3aa Merge branch 'main' into feature/add-multitask-dit 2025-12-15 18:08:21 -08:00
Bryson Jones 534e143b0c add conditional transformers import to match all other policies that use transformers lib 2025-12-15 10:08:47 -08:00
Bryson Jones 23382c0ac5 add multitask dit to toc for docs 2025-12-15 10:02:00 -08:00
Bryson Jones 4eda54c7fe Merge branch 'main' into feature/add-multitask-dit 2025-12-15 09:57:27 -08:00
Bryson Jones a632dd3af4 Merge branch 'main' into feature/add-multitask-dit 2025-12-15 08:22:05 -08:00
Bryson Jones e4a1b27fd3 Merge branch 'main' into feature/add-multitask-dit 2025-12-14 09:10:48 -08:00
Bryson Jones 71f359ca6e refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit 2025-12-11 12:09:54 -08:00
Bryson Jones 51dfee43f4 rename config param for multiple vision encoders 2025-12-11 09:58:57 -08:00
Bryson Jones 1f74982469 note origins of each training objective 2025-12-11 09:55:41 -08:00
Bryson Jones 8e3a1e8945 add more descriptions and depth to multitask dit tutorial 2025-12-11 09:45:53 -08:00
Bryson Jones 43c335d0d7 reformat and clean up tutorial for multitask dit policy 2025-12-11 09:33:30 -08:00
Bryson Jones dd4ef1383f fix nit formatting in generate actions fcn 2025-12-11 09:21:26 -08:00
Bryson Jones f3823e8bcd remove the base classes since we don't need to be able to extend 2025-12-11 09:20:25 -08:00
Bryson Jones c398a146b3 use constants for indexing into batches and remove env state references 2025-12-11 09:13:38 -08:00
Bryson Jones 9b47c5fac9 move policy to top of file 2025-12-11 09:04:38 -08:00
Bryson Jones 56dbeed89f add processor tests to multitask dit tests 2025-12-11 09:04:23 -08:00
Bryson Jones 67b1a9eeb1 update typo in test instruction comment 2025-12-11 08:28:35 -08:00
Bryson Jones 86e0ee787d remove environment state conditioning 2025-12-11 08:26:21 -08:00
Bryson Jones ba968e84f1 fix bugs when testing on hardware 2025-12-10 16:26:57 -08:00
Bryson Jones d49d3390f6 Merge branch 'main' into feature/add-multitask-dit 2025-12-10 14:44:04 -08:00
Bryson Jones f1ac454800 add tutorial to training with multi_task_dit 2025-12-10 14:43:27 -08:00
Bryson Jones 10cfc17705 remove redundant asserts 2025-12-10 14:43:27 -08:00
Bryson Jones 5524a0d7a7 split up select action return statement 2025-12-10 14:43:27 -08:00
Bryson Jones 3a16a002f8 add torch.no_grad decorators 2025-12-10 14:43:27 -08:00
Bryson Jones 3b2a4f548c merge all modules files into the main modeling file 2025-12-10 14:43:27 -08:00
Bryson Jones b92dc82ddd add references to the modeling file comments 2025-12-10 14:43:27 -08:00
Bryson Jones 103230c64c simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives 2025-12-10 14:43:27 -08:00
Bryson Jones cdacc090cd update docstring for multitask dit policy processor file 2025-12-10 14:43:27 -08:00
Bryson Jones 6f856016c5 adjust factory comment 2025-12-10 14:43:27 -08:00
Bryson Jones adabb37af6 remove dino vision encoder and simplify text and vision encoders by removing inheritance structure 2025-12-10 14:43:27 -08:00
Bryson Jones 55e19ff9a7 update readme and citations for multitask dit policy 2025-12-10 14:43:27 -08:00
Bryson Jones 22714af08d Merge branch 'main' into feature/add-multitask-dit 2025-12-09 20:45:59 -08:00
Bryson Jones 46ebcc2f7d add RoPE attention module as this is shown to help training dynamics and generation quality for DiTs 2025-12-09 08:42:56 -08:00
Bryson Jones a0d5a088e3 Merge branch 'main' into feature/add-multitask-dit 2025-12-09 07:44:05 -08:00
Bryson Jones 34499cbc1b Merge branch 'main' into feature/add-multitask-dit 2025-11-28 16:45:07 -08:00
Bryson Jones 8b9fada80f expand the observation encoder to support differnt size encoders for vision and text 2025-11-21 14:31:35 -08:00
Bryson Jones ab97d5c019 Merge branch 'main' into feature/add-multitask-dit 2025-11-13 11:23:37 -08:00
Bryson Jones 14a7a4d7d4 Add multitask diffusion transformer policy
Add multitask diffusion transformer policy
2025-11-12 16:20:59 -08:00
153 changed files with 8040 additions and 3827 deletions
+8 -1
View File
@@ -44,7 +44,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -61,6 +61,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -89,5 +90,11 @@ jobs:
- name: Install lerobot with test extras
run: uv sync --extra "test"
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest
run: uv run pytest tests -vv --maxfail=10
+16 -1
View File
@@ -37,7 +37,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest action is built, canceling older runs.
@@ -60,6 +60,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -87,6 +88,12 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
@@ -162,6 +169,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -173,6 +181,13 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
+20 -4
View File
@@ -28,7 +28,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
@@ -119,6 +119,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
@@ -130,6 +131,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on CPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -146,6 +152,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -157,6 +164,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -174,6 +186,7 @@ jobs:
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -185,12 +198,15 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
run: pytest -vv tests/training/
+1 -1
View File
@@ -50,7 +50,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Run pre-commit hooks
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
+2 -10
View File
@@ -22,7 +22,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
jobs:
# This job builds the Python package and publishes it to PyPI
@@ -45,7 +45,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Extract Version
id: extract_info
@@ -83,14 +83,6 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
+13 -2
View File
@@ -29,7 +29,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
# Ensures that only the latest action is built, canceling older runs.
@@ -48,6 +48,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -79,7 +80,11 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -137,6 +142,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -148,6 +154,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv
- name: Run end-to-end tests
+2 -2
View File
@@ -13,7 +13,7 @@
# limitations under the License.
default_language_version:
python: python3.10
python: python3.12
exclude: "tests/artifacts/.*\\.safetensors$"
@@ -55,7 +55,7 @@ repos:
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
args: [--py312-plus]
##### Markdown Quality #####
- repo: https://github.com/rbubley/mirrors-prettier
+25
View File
@@ -0,0 +1,25 @@
# AI Usage Policy
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
## Remember the Human Maintainers
Please remember that LeRobot is maintained by a dedicated team of humans.
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
## AI is Welcome Here
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+1 -1
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
## Ways to Contribute
+1
View File
@@ -1,2 +1,3 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
+5 -5
View File
@@ -100,11 +100,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+3 -1
View File
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
# Define Python version argument
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
# Configure environment variables
ENV DEBIAN_FRONTEND=noninteractive \
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+1 -1
View File
@@ -19,7 +19,7 @@
# docker run -it --rm lerobot-user
# Configure the base image
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim
# Configure environment variables
+4
View File
@@ -29,6 +29,8 @@
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
- sections:
- local: act
@@ -45,6 +47,8 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
+3
View File
@@ -88,5 +88,8 @@ lerobot-record \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```
+1 -1
View File
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+4 -4
View File
@@ -32,7 +32,7 @@ version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"
requires-python = ">= 3.12"
[build-system]
build-backend = # your-build-backend
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any
from typing import Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
@@ -102,7 +102,7 @@ Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Dict, Any
from typing import Any
import torch
+6 -3
View File
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
### Hardware
- EarthRover Mini robot
- Computer with Python 3.10 or newer
- Computer with Python 3.12 or newer
- Internet connection
### Setting Up the Frodobots SDK
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
@@ -192,6 +192,9 @@ lerobot-record \
--dataset.num_episodes=2 \
--dataset.fps=10 \
--dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
+2 -2
View File
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
hf auth login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
hf repo create my-org/my-custom-env
# Initialize git and push
git init
+6 -3
View File
@@ -120,9 +120,12 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm"
--policy.path=<user>/groot-bimanual # your trained model
--dataset.episode_time_s=30
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
```
+6
View File
@@ -230,6 +230,9 @@ lerobot-record \
--dataset.episode_time_s=5 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -273,5 +276,8 @@ lerobot-record \
--dataset.repo_id=<USER>/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```
+12 -6
View File
@@ -159,13 +159,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
@@ -185,7 +185,10 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=${HF_USER}/record-test \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube"
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
</hfoption>
<hfoption id="API example">
@@ -324,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
You can also push your local dataset to the Hub manually, running:
```bash
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
```
#### Record function
@@ -488,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
Once training is done, upload the latest checkpoint with:
```bash
huggingface-cli upload ${HF_USER}/act_so101_test \
hf upload ${HF_USER}/act_so101_test \
outputs/train/act_so101_test/checkpoints/last/pretrained_model
```
@@ -496,7 +499,7 @@ You can also upload intermediate checkpoints with:
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
hf upload ${HF_USER}/act_so101_test${CKPT} \
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
```
@@ -515,6 +518,9 @@ lerobot-record \
--display_data=false \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
+10 -6
View File
@@ -1,6 +1,6 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
@@ -11,10 +11,10 @@ bash Miniforge3-$(uname)-$(uname -m).sh
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
Create a virtual environment with Python 3.12, using conda:
```bash
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
@@ -40,6 +40,13 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
> [!NOTE]
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command:
>
> ```bash
> conda install evdev -c conda-forge
> ```
## Step 3: Install LeRobot 🤗
### From Source
@@ -83,9 +90,6 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
+2 -2
View File
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
+4 -1
View File
@@ -41,7 +41,10 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=${HF_USER}/record-test \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube"
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
See the [recording guide](./il_robots#record-a-dataset) for more details.
+340
View File
@@ -0,0 +1,340 @@
# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install it with:
```bash
pip install lerobot[multi_task_dit]
```
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```python
policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete training command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
+9 -5
View File
@@ -66,12 +66,13 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase
All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`.
- Run this example to teleoperate:
```bash
python examples/phone_to_so100/teleoperate.py
cd examples/phone_to_so100
python teleoperate.py
```
After running the example:
@@ -84,19 +85,22 @@ Additionally you can customize mapping or safety limits by editing the processor
- Run this example to record a dataset, which saves absolute end effector observations and actions:
```bash
python examples/phone_to_so100/record.py
cd examples/phone_to_so100
python record.py
```
- Run this example to replay recorded episodes:
```bash
python examples/phone_to_so100/replay.py
cd examples/phone_to_so100
python replay.py
```
- Run this example to evaluate a pretrained policy:
```bash
python examples/phone_to_so100/evaluate.py
cd examples/phone_to_so100
python evaluate.py
```
### Important pipeline steps and options
-5
View File
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
-5
View File
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
+10 -15
View File
@@ -43,16 +43,11 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training a Custom FAST Tokenizer
You have two options for the FAST tokenizer:
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
@@ -114,15 +109,15 @@ lerobot-train \
### Key Training Parameters
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
## Inference
@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```
+6
View File
@@ -159,6 +159,9 @@ lerobot-record \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -198,6 +201,9 @@ lerobot-record \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
+3
View File
@@ -106,6 +106,9 @@ lerobot-record \
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
+155
View File
@@ -0,0 +1,155 @@
# Streaming Video Encoding Guide
## 1. Overview
Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of:
1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's
Frames can be encoded in real-time during capture:
1. Capture frame -> queue to encoder thread -> encode to MP4 directly
This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes.
## 2. Tuning Parameters
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
## 3. Performance Considerations
Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between:
- **Control loop** (reading cameras, control the robot, writing non-video data)
- **Encoder threads** (one pool per camera)
- **Rerun visualization** (if enabled)
- **OS and other processes**
### Resolution & Number of Cameras Impact
| Setup | Throughput (px/sec) | CPU Encoding Load | Notes |
| ------------------------- | ------------------- | ----------------- | ------------------------------ |
| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems |
| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems |
| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU |
### `encoder_threads` Tuning
This parameter controls how many threads each encoder instance uses internally:
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
- **`None` (default)**: Lets the codec decide. Information available in the codec logs.
### Backpressure and Frame Dropping
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
1. The queue fills up (consuming RAM)
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"`
4. At episode end, total dropped frames per camera are reported
### Symptoms of Encoder Falling Behind
- **System feels laggy and freezes**: all CPUs are at 100%
- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset
- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected
- **Accumulated rerun lag**: Visualization falls behind real-time
## 4. Hardware-Accelerated Encoding
### When to Use
Use HW encoding when:
- CPU is the bottleneck (dropped frames, choppy robot, rerun lag)
- You have compatible hardware (GPU or dedicated encoder)
- You're recording at high throughput (high resolution or with many cameras)
### Choosing a Codec
| Codec | CPU Usage | File Size | Quality | Notes |
| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- |
| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive |
| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU |
| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems |
### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
> [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
> [!NOTE]
> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time.
## 5. Troubleshooting
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations
These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually.
### High-End Systems: modern 12+ cores (24+ threads)
A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available.
```bash
# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
lerobot-record --dataset.encoder_threads=5 ...
# 3camsx 1920x1080x3 @30fps: Might require some tuning.
```
### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon
A throughput between ~80-300M px/sec should be possible in CPU.
```bash
# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
lerobot-record --dataset.encoder_threads=2 ...
# 2camsx 1920x1080x3 @30fps: Might require some tuning.
```
### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5
On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding.
```bash
# 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
```
## 7. Closing note
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
+10 -4
View File
@@ -123,7 +123,7 @@ SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
@@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
@@ -229,7 +229,10 @@ lerobot-record \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
@@ -279,7 +282,10 @@ lerobot-record \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
**Note**: Update `server_address` to match your robot's camera server IP.
+1 -1
View File
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
+490
View File
@@ -0,0 +1,490 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SLURM-distributed SARM RA-BC annotation pipeline.
Computes SARM progress values for all frames in a dataset, distributed across
SLURM workers, then merges the shards into a single sarm_progress.parquet.
Two subcommands, each a separate SLURM submission:
compute N workers, each computes progress for a subset of episodes
aggregate 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
Usage:
python slurm_compute_rabc.py compute \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--stride 10 --device cpu --workers 50 --partition cpu
python slurm_compute_rabc.py aggregate \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--partition cpu --push-to-hub
"""
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
class ComputeProgressShards(PipelineStep):
"""Each worker computes SARM progress for its assigned episodes."""
def __init__(
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
):
super().__init__()
if stride < 1:
raise ValueError(f"stride must be >= 1, got {stride}")
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.stride = stride
self.head_mode = head_mode
self.device = device
self.shard_dir = shard_dir
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
)
from lerobot.utils.utils import init_logging
init_logging()
dataset, reward_model, preprocess = load_sarm_resources(
self.repo_id,
self.reward_model_path,
self.device,
)
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
frame_gap = reward_model.config.frame_gap
center_idx = reward_model.config.n_obs_steps // 2
dual_mode = reward_model.config.uses_dual_heads
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
compute_dense = self.head_mode in ("dense", "both") and dual_mode
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
if not my_episodes:
logging.info(f"Rank {rank}: no episodes assigned")
return
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
all_rows = []
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
ep = dataset.meta.episodes[ep_idx]
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
task = dataset[ep_start].get("task", "perform the task")
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
if self.stride > 1:
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
if (ep_end - 1) not in compute_indices:
compute_indices.append(ep_end - 1)
compute_indices = sorted(set(compute_indices))
else:
compute_indices = all_ep_indices
frame_results = {}
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
try:
sample = dataset[qi]
batch = {
image_key: sample[image_key],
"task": task,
"index": qi,
"episode_index": ep_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
vf = processed["video_features"].to(self.device)
tf = processed["text_features"].to(self.device)
sf = processed.get("state_features")
if sf is not None:
sf = sf.to(self.device)
lengths = processed.get("lengths")
sparse_val = dense_val = np.nan
if compute_sparse:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="sparse",
)
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
if compute_dense:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="dense",
)
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
frame_results[qi] = (sparse_val, dense_val)
except Exception as e:
logging.warning(f"Failed frame {qi}: {e}")
if not frame_results:
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
continue
# Interpolate to all frames in this episode
computed_idx = np.array(sorted(frame_results.keys()))
all_frame_arr = np.arange(ep_start, ep_end)
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
if self.stride > 1 and len(computed_idx) > 1:
if compute_sparse:
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
if compute_dense:
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
output_frames = all_frame_arr
else:
# Use only successfully computed frames to avoid indexing mismatch on failures
output_frames = computed_idx
for i, fi in enumerate(output_frames):
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
if compute_sparse:
row["progress_sparse"] = float(sparse_vals[i])
if compute_dense:
row["progress_dense"] = float(dense_vals[i])
all_rows.append(row)
if all_rows:
import pandas as pd
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
shard_dir = Path(self.shard_dir)
shard_dir.mkdir(parents=True, exist_ok=True)
out = shard_dir / f"shard_{rank:05d}.parquet"
pq.write_table(table, out)
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
class AggregateProgress(PipelineStep):
"""Merge all shard parquets into final sarm_progress.parquet."""
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
super().__init__()
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.shard_dir = shard_dir
self.push_to_hub = push_to_hub
def run(self, data=None, rank: int = 0, world_size: int = 1):
import datetime
import logging
import os
from pathlib import Path
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
init_logging()
if rank != 0:
return
shard_dir = Path(self.shard_dir)
shards = sorted(shard_dir.glob("shard_*.parquet"))
if not shards:
raise FileNotFoundError(f"No shards found in {shard_dir}")
# Log shard modification time range to help detect stale files
mtimes = [os.path.getmtime(s) for s in shards]
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
df = df.sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
out_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out_path)
logging.info(f"Saved {len(df)} rows to {out_path}")
for col in ["progress_sparse", "progress_dense"]:
if col in df.columns:
v = df[col].dropna()
logging.info(
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
)
if self.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = "sarm_progress.parquet"
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(out_path),
path_in_repo=hub_path,
repo_id=self.repo_id,
repo_type="dataset",
)
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
def make_compute_executor(
repo_id,
reward_model_path,
stride,
head_mode,
device,
shard_dir,
logs_dir,
job_name,
slurm,
workers,
partition,
cpus_per_task,
mem_per_cpu,
):
kwargs = {
"pipeline": [
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": workers,
"workers": workers,
"time": "24:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": workers, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def make_aggregate_executor(
repo_id,
reward_model_path,
shard_dir,
logs_dir,
job_name,
slurm,
partition,
cpus_per_task,
mem_per_cpu,
push_to_hub,
):
kwargs = {
"pipeline": [
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": 1,
"workers": 1,
"time": "02:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": 1, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def _add_shared_args(p):
p.add_argument(
"--repo-id",
type=str,
required=True,
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
)
p.add_argument(
"--shard-dir",
type=Path,
default=Path("rabc_shards"),
help="Directory to read/write per-rank parquet shards.",
)
p.add_argument(
"--logs-dir",
type=Path,
default=Path("logs"),
help="Directory for datatrove logs.",
)
p.add_argument(
"--job-name",
type=str,
default=None,
help="SLURM job name (defaults to rabc_<subcommand>).",
)
p.add_argument(
"--slurm",
type=int,
default=1,
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
)
p.add_argument(
"--partition",
type=str,
default=None,
help="SLURM partition to submit to.",
)
p.add_argument(
"--cpus-per-task",
type=int,
default=4,
help="Number of CPUs per SLURM task.",
)
p.add_argument(
"--mem-per-cpu",
type=str,
default="4G",
help="Memory per CPU, e.g. '4G' or '1950M'.",
)
def main():
parser = argparse.ArgumentParser(
description="SLURM-distributed SARM RA-BC annotation pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
sub = parser.add_subparsers(dest="command", required=True)
# compute subcommand
cp = sub.add_parser(
"compute",
help="Distribute progress computation across SLURM workers.",
)
_add_shared_args(cp)
cp.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model.",
)
cp.add_argument(
"--stride",
type=int,
default=1,
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
)
cp.add_argument(
"--head-mode",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="Which reward head(s) to compute.",
)
cp.add_argument(
"--device",
type=str,
default="cpu",
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
)
cp.add_argument(
"--workers",
type=int,
default=50,
help="Number of parallel SLURM tasks (one shard per worker).",
)
# aggregate subcommand
ap = sub.add_parser(
"aggregate",
help="Merge per-rank shards into a single sarm_progress.parquet.",
)
_add_shared_args(ap)
ap.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
)
ap.add_argument(
"--push-to-hub",
action="store_true",
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
)
args = parser.parse_args()
job_name = args.job_name or f"rabc_{args.command}"
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
kwargs["job_name"] = job_name
command = kwargs.pop("command")
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
executor.run()
if __name__ == "__main__":
main()
+14 -17
View File
@@ -4,6 +4,7 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
@@ -11,7 +12,6 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -40,9 +40,8 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers()
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -84,26 +83,24 @@ def run_learner(
else:
batch[key] = online_batch[key]
def batch_iter(b=batch):
while True:
yield b
loss, _ = policy_learner.forward(batch)
stats = algorithm.update(batch_iter())
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
@@ -147,15 +144,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy (returns full action: continuous + discrete)
# Get action from policy
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
+56 -119
View File
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.4"
version = "0.4.5"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
requires-python = ">=3.12"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,7 +50,8 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -59,28 +60,30 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<4.2.0",
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"huggingface-hub>=1.0.0,<2.0.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"packaging>=24.2,<26.0",
"torch>=2.2.1,<2.11.0",
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pynput>=1.7.8,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"draccus==0.10.0", # TODO: Remove ==
"draccus==0.10.0", # TODO: Relax version constraint
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
@@ -95,14 +98,20 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["python-can>=4.2.0,<5.0.0"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]
# Robots
openarms = ["lerobot[damiao]"]
@@ -114,30 +123,31 @@ unitree_g1 = [
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"lerobot[matplotlib-dep]",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
# Policies
wallx = [
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
"lerobot[transformers-dep]",
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"lerobot[peft]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -146,13 +156,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -160,13 +170,19 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
all = [
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
# helps pip's resolver converge by constraining scipy early, before it encounters
# the loose scipy requirements from transitive deps like dm-control and metaworld.
"scipy>=1.14.0,<2.0.0",
"lerobot[dynamixel]",
"lerobot[gamepad]",
"lerobot[hopejr]",
@@ -174,8 +190,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -187,7 +203,7 @@ all = [
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[phone]",
"lerobot[libero]",
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[peft]",
@@ -212,11 +228,14 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
[tool.ruff]
target-version = "py310"
target-version = "py312"
line-length = 110
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
@@ -308,7 +327,7 @@ default.extend-ignore-identifiers-re = [
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
[tool.mypy]
python_version = "3.10"
python_version = "3.12"
ignore_missing_imports = true
follow_imports = "skip"
# warn_return_any = true
@@ -392,85 +411,3 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "peft" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]
+5 -3
View File
@@ -63,9 +63,9 @@ from lerobot.transport import (
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
@@ -485,8 +485,9 @@ class RobotClient:
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
# TODO: Assert if checking robot support is still needed with the plugin system
# if cfg.robot.type not in SUPPORTED_ROBOTS:
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
@@ -512,4 +513,5 @@ def async_client(cfg: RobotClientConfig):
if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client
+1 -1
View File
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
-12
View File
@@ -211,15 +211,3 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm name registered in RLAlgorithmConfig registry
algorithm: str = "sac"
# Data mixer strategy name. Currently supports "online_offline"
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer
online_ratio: float = 0.5
# RL trainer iterator
async_prefetch: bool = True
queue_size: int = 2
+3 -1
View File
@@ -289,7 +289,9 @@ def aggregate_datasets(
logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
dst_meta.tasks = pd.DataFrame(
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
+7
View File
@@ -7,6 +7,13 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+41 -33
View File
@@ -89,8 +89,8 @@ def delete_episodes(
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -152,7 +152,7 @@ def split_dataset(
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Base directory for output datasets. If None, uses default location.
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
Examples:
Split by specific episodes
@@ -243,8 +243,8 @@ def merge_datasets(
Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +288,8 @@ def modify_features(
dataset: The source LeRobotDataset.
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features modified.
@@ -390,8 +390,8 @@ def add_features(
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with all features added.
@@ -427,8 +427,8 @@ def remove_feature(
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features removed.
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[float, float]],
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified time ranges and re-encodes them with
This function decodes frames from specified frame ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
time_ranges = sorted(episodes_to_keep)
frame_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Check if frame is in any of our desired frame ranges.
# Skip ranges that have already passed.
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(time_ranges):
if range_idx >= len(frame_ranges):
break
# Check if frame is in current range.
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of time ranges to keep, in sorted order.
# Build list of frame ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[float, float]] = []
episodes_to_keep_ranges: list[tuple[int, int]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -1470,7 +1475,9 @@ def modify_tasks(
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
new_task_df = pd.DataFrame(
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
@@ -1524,7 +1531,7 @@ def modify_tasks(
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -1543,8 +1550,8 @@ def convert_image_to_video_dataset(
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
@@ -1595,6 +1602,7 @@ def convert_image_to_video_dataset(
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
+125 -28
View File
@@ -68,6 +68,7 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
@@ -75,11 +76,11 @@ from lerobot.datasets.video_utils import (
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
class LeRobotDatasetMetadata:
@@ -313,7 +314,7 @@ class LeRobotDatasetMetadata:
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
@@ -545,12 +546,19 @@ class LeRobotDatasetMetadata:
def _encode_video_worker(
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
video_key: str,
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
shutil.rmtree(img_dir)
return temp_path
@@ -570,6 +578,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -653,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -683,12 +694,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
encoding is CPU-heavy.
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
"""
super().__init__()
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
@@ -700,7 +716,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
self.vcodec = vcodec
self.vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
# Unused attributes
self.image_writer = None
@@ -708,6 +725,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self._streaming_encoder = None
self.root.mkdir(exist_ok=True, parents=True)
@@ -729,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -749,6 +767,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Initialize streaming encoder for resumed recording
if streaming_encoding and len(self.meta.video_keys) > 0:
self._streaming_encoder = StreamingVideoEncoder(
fps=self.meta.fps,
vcodec=self.vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
writer = getattr(self, "writer", None)
@@ -808,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -1104,6 +1135,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
self._close_writer()
self.meta._close_writer()
if self._streaming_encoder is not None:
self._streaming_encoder.close()
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
@@ -1158,6 +1191,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
# Start streaming encoder on first frame of episode (once, before iterating keys)
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self.meta.video_keys),
temp_dir=self.root,
)
# Add frame features to episode_buffer
for key in frame:
if key not in self.features:
@@ -1165,7 +1205,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)
if self.features[key]["dtype"] in ["image", "video"]:
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
self._streaming_encoder.feed_frame(key, frame[key])
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
elif self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@@ -1226,13 +1269,38 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self.batch_encoding_size > 1
if has_video_keys and not use_batched_encoding:
if use_streaming:
# Compute stats for non-video features only (video stats come from encoder)
non_video_buffer = {
k: v
for k, v in episode_buffer.items()
if self.features.get(k, {}).get("dtype") not in ("video",)
}
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
else:
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_metadata = self._save_episode_data(episode_buffer)
if use_streaming:
# Finish streaming encoding and collect results
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self.meta.video_keys:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
@@ -1246,6 +1314,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root,
self.fps,
self.vcodec,
self._encoder_threads,
): video_key
for video_key in self.meta.video_keys
}
@@ -1514,6 +1583,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata
def clear_episode_buffer(self, delete_images: bool = True) -> None:
# Cancel streaming encoder if active
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
# Clean up image files for the current episode buffer
if delete_images:
# Wait for the async image writer to finish
@@ -1561,7 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
return _encode_video_worker(
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
)
@classmethod
def create(
@@ -1578,10 +1653,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -1590,6 +1668,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
features=features,
root=root,
use_videos=use_videos,
metadata_buffer_size=metadata_buffer_size,
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
@@ -1599,6 +1678,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
obj.vcodec = vcodec
obj._encoder_threads = encoder_threads
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
@@ -1620,6 +1700,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._lazy_loading = False
obj._recorded_frames = 0
obj._writer_closed_for_reading = False
# Initialize streaming encoder
if streaming_encoding and len(obj.meta.video_keys) > 0:
obj._streaming_encoder = StreamingVideoEncoder(
fps=fps,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
else:
obj._streaming_encoder = None
return obj
@@ -1675,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
if extra_keys:
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
+3 -4
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any, Generic, TypeVar
from typing import Any
import datasets
import numpy as np
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
T = TypeVar("T")
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
@@ -341,6 +339,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks
@@ -1233,7 +1232,7 @@ class LookAheadError(Exception):
pass
class Backtrackable(Generic[T]):
class Backtrackable[T]:
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory
--root=/path/to/local/dataset/directory \
--push-to-hub=false
N.B. Path semantics (v2): --root is the exact dataset folder containing
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
```
"""
@@ -105,7 +108,7 @@ episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
NEW
meta/episodes/chunk-000/episodes_000.parquet
meta/episodes/chunk-000/file_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
@@ -113,15 +116,16 @@ tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}
NEW
meta/tasks/chunk-000/file_000.parquet
meta/tasks.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
meta/episodes/chunk-000/file_000.parquet
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
-------------------------
UPDATE
meta/info.json
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
write_tasks(df_tasks, new_root)
@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
image_keys = get_image_keys(root)
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
@@ -211,9 +214,23 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
logging.info(f"Converting data files from {len(ep_paths)} episodes")
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
# Check if we need to start a new file BEFORE creating metadata
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
# Write the accumulated data files
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Move to next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Reset for the next file
size_in_mb = 0
paths_to_cat = []
# Now create metadata with correct chunk/file indices
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
@@ -224,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Reset for the next file
size_in_mb = ep_size_in_mb
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
paths_to_cat.append(ep_path)
# Write remaining data if any
if paths_to_cat:
@@ -469,7 +473,7 @@ def convert_dataset(
# Set root based on whether local dataset path is provided
use_local_dataset = False
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
if root.exists():
validate_local_dataset_version(root)
use_local_dataset = True
@@ -553,7 +557,7 @@ if __name__ == "__main__":
"--root",
type=str,
default=None,
help="Local directory to use for downloading/writing the dataset.",
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
)
parser.add_argument(
"--push-to-hub",
+480 -46
View File
@@ -13,25 +13,106 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import glob
import importlib
import logging
import queue
import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar
import av
import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
return options
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
pass # nosec B110
return available
def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logging.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logging.info(f"Auto-selected video codec: {encoder}")
return encoder
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
@@ -146,16 +227,17 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -167,7 +249,11 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
return closest_frames
@@ -272,15 +358,16 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -309,14 +396,13 @@ def encode_video_frames(
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
if vcodec not in ["h264", "hevc", "libsvtav1"]:
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
vcodec = resolve_vcodec(vcodec)
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -347,21 +433,22 @@ def encode_video_frames(
width, height = dummy_image.size
# Define video codec options
video_options = {}
if g is not None:
video_options["g"] = str(g)
if crf is not None:
video_options["crf"] = str(crf)
video_options = _get_codec_options(vcodec, g, crf, preset)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if vcodec == "libsvtav1":
video_options["preset"] = str(preset) if preset is not None else "12"
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
# Set logging level
if log_level is not None:
@@ -480,6 +567,348 @@ def concatenate_video_files(
Path(tmp_concatenate_path).unlink()
class _CameraEncoderThread(threading.Thread):
"""A thread that encodes video frames streamed via a queue into an MP4 file.
One instance is created per camera per episode. Frames are received as numpy arrays
from the main thread, encoded in real-time using PyAV (which releases the GIL during
encoding), and written to disk. Stats are computed incrementally using
RunningQuantileStats and returned via result_queue.
"""
def __init__(
self,
video_path: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int | None,
crf: int | None,
preset: int | None,
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None:
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
container = None
output_stream = None
stats_tracker = RunningQuantileStats()
frame_count = 0
try:
logging.getLogger("libav").setLevel(av.logging.WARNING)
while True:
try:
frame_data = self.frame_queue.get(timeout=1)
except queue.Empty:
if self.stop_event.is_set():
break
continue
if frame_data is None:
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
if container is None:
height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
if packet:
container.mux(packet)
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
frame_count += 1
# Flush encoder
if output_stream is not None:
packet = output_stream.encode()
if packet:
container.mux(packet)
if container is not None:
container.close()
av.logging.restore_default_callback()
# Get stats and put on result queue
if frame_count >= 2:
stats = stats_tracker.get_statistics()
self.result_queue.put(("ok", stats))
else:
self.result_queue.put(("ok", None))
except Exception as e:
logging.error(f"Encoder thread error: {e}")
if container is not None:
with contextlib.suppress(Exception):
container.close()
self.result_queue.put(("error", str(e)))
class StreamingVideoEncoder:
"""Manages per-camera encoder threads for real-time video encoding during recording.
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
this class streams frames directly to encoder threads, eliminating the
PNG round-trip and making save_episode() near-instant.
Uses threading instead of multiprocessing to avoid the overhead of pickling large
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
so encoding runs in parallel with the main recording loop.
"""
def __init__(
self,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
self.fps = fps
self.vcodec = resolve_vcodec(vcodec)
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
self._threads: dict[str, _CameraEncoderThread] = {}
self._stop_events: dict[str, threading.Event] = {}
self._video_paths: dict[str, Path] = {}
self._dropped_frames: dict[str, int] = {}
self._episode_active = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files
"""
if self._episode_active:
self.cancel_episode()
self._dropped_frames.clear()
for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1)
stop_event = threading.Event()
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=self.vcodec,
pix_fmt=self.pix_fmt,
g=self.g,
crf=self.crf,
preset=self.preset,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self.encoder_threads,
)
encoder_thread.start()
self._frame_queues[video_key] = frame_queue
self._result_queues[video_key] = result_queue
self._threads[video_key] = encoder_thread
self._stop_events[video_key] = stop_event
self._video_paths[video_key] = video_path
self._episode_active = True
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
"""Feed a frame to the encoder for a specific camera.
A copy of the image is made before enqueueing to prevent race conditions
with camera drivers that may reuse buffers. If the encoder queue is full
(encoder can't keep up), the frame is dropped with a warning instead of
crashing the recording session.
Args:
video_key: The video feature key
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
Raises:
RuntimeError: If the encoder thread has crashed
"""
if not self._episode_active:
raise RuntimeError("No active episode. Call start_episode() first.")
thread = self._threads[video_key]
if not thread.is_alive():
# Check for error
try:
status, msg = self._result_queues[video_key].get_nowait()
if status == "error":
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
except queue.Empty:
pass
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
try:
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
except queue.Full:
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
count = self._dropped_frames[video_key]
# Log periodically to avoid spam (1st, then every 10th)
if count == 1 or count % 10 == 0:
logging.warning(
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
)
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
"""Finish encoding the current episode.
Sends sentinel values, waits for encoder threads to complete,
and collects results.
Returns:
Dict mapping video_key to (mp4_path, stats_dict_or_None)
"""
if not self._episode_active:
raise RuntimeError("No active episode to finish.")
results = {}
# Report dropped frames
for video_key, count in self._dropped_frames.items():
if count > 0:
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
# Send sentinel to all queues
for video_key in self._frame_queues:
self._frame_queues[video_key].put(None)
# Wait for all threads and collect results
for video_key in self._threads:
self._threads[video_key].join(timeout=120)
if self._threads[video_key].is_alive():
logging.error(f"Encoder thread for {video_key} did not finish in time")
self._stop_events[video_key].set()
self._threads[video_key].join(timeout=5)
results[video_key] = (self._video_paths[video_key], None)
continue
try:
status, data = self._result_queues[video_key].get(timeout=5)
if status == "error":
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
results[video_key] = (self._video_paths[video_key], data)
except queue.Empty:
logging.error(f"No result from encoder thread for {video_key}")
results[video_key] = (self._video_paths[video_key], None)
self._cleanup()
self._episode_active = False
return results
def cancel_episode(self) -> None:
"""Cancel the current episode, stopping encoder threads and cleaning up."""
if not self._episode_active:
return
# Signal all threads to stop
for video_key in self._stop_events:
self._stop_events[video_key].set()
# Wait for threads to finish
for video_key in self._threads:
self._threads[video_key].join(timeout=5)
# Clean up temp MP4 files
video_path = self._video_paths.get(video_key)
if video_path is not None and video_path.exists():
shutil.rmtree(str(video_path.parent), ignore_errors=True)
self._cleanup()
self._episode_active = False
def close(self) -> None:
"""Close the encoder, canceling any in-progress episode."""
if self._episode_active:
self.cancel_episode()
def _cleanup(self) -> None:
"""Clean up queues and thread tracking dicts."""
for q in self._frame_queues.values():
with contextlib.suppress(Exception):
while not q.empty():
q.get_nowait()
self._frame_queues.clear()
self._result_queues.clear()
self._threads.clear()
self._stop_events.clear()
self._video_paths.clear()
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
@@ -514,7 +943,7 @@ with warnings.catch_warnings():
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting audio stream information
audio_info = {}
@@ -546,7 +975,7 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.ERROR)
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting video stream information
video_info = {}
@@ -632,8 +1061,15 @@ class VideoEncodingManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Handle any remaining episodes that haven't been batch encoded
if self.dataset.episodes_since_last_encoding > 0:
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
if streaming_encoder is not None:
# Handle streaming encoder cleanup
if exc_type is not None:
streaming_encoder.cancel_episode()
streaming_encoder.close()
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
else:
@@ -650,8 +1086,8 @@ class VideoEncodingManager:
# Finalize the dataset to properly close all writers
self.dataset.finalize()
# Clean up episode images if recording was interrupted
if exc_type is not None:
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and streaming_encoder is None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
@@ -665,14 +1101,12 @@ class VideoEncodingManager:
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"
# Check for any remaining PNG files
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
# Only remove the images directory if no PNG files remain
if img_dir.exists():
if img_dir.exists():
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
+4 -4
View File
@@ -29,7 +29,7 @@ from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
from typing import Protocol, TypeAlias
from typing import Protocol
import serial
from deepdiff import DeepDiff
@@ -38,8 +38,8 @@ from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
type NameOrID = str | int
type Value = int | float
logger = logging.getLogger(__name__)
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
MotorsBus = SerialMotorsBus
@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .robstride import RobstrideMotorsBus
from .tables import *
File diff suppressed because it is too large Load Diff
+120
View File
@@ -0,0 +1,120 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
O0 = 0
O1 = 1
O2 = 2
O3 = 3
O4 = 4
O5 = 5
ELO5 = 6
O6 = 7
class CommMode(IntEnum):
PrivateProtocole = 0
CANopen = 1
MIT = 2
# Control modes
class ControlMode(IntEnum):
MIT = 0
POS_VEL = 1
VEL = 2
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = {
MotorType.O0: (12.57, 33, 14),
MotorType.O1: (12.57, 44, 17),
MotorType.O2: (12.57, 33, 20),
MotorType.O3: (12.57, 33, 60),
MotorType.O4: (12.57, 33, 120),
MotorType.O5: (12.57, 50, 5.5),
MotorType.ELO5: (12.57, 50, 6),
MotorType.O6: (112.5, 50, 36),
}
# Motor model names
MODEL_NAMES = {
MotorType.O0: "O0",
MotorType.O1: "O1",
MotorType.O2: "O2",
MotorType.O3: "O3",
MotorType.O4: "O4",
MotorType.O5: "O5",
MotorType.ELO5: "ELO5",
MotorType.O6: "O6",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"O0": 65536,
"O1": 65536,
"O2": 65536,
"O3": 65536,
"O4": 65536,
"O5": 65536,
"ELO5": 65536,
"O6": 65536,
}
# CAN baudrates supported by Robstride motors
AVAILABLE_BAUDRATES = [
1000000, # 4: 1 mbps (default)
]
DEFAULT_BAUDRATE = 1000000
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s
# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_CLEAR_FAULT = 0xFB
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.001
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02
+2
View File
@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -28,6 +29,7 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"PI0FastConfig",
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
backbone. If None, no resizing is done and the original image resolution is used.
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
if self.resize_shape is None and self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for `{key}`."
)
# Check that all input images have the same shape.
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
# Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
+21 -6
View File
@@ -18,10 +18,9 @@ from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, Unpack
import torch
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -32,6 +31,7 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -67,8 +67,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -87,6 +86,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -147,8 +150,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -163,6 +166,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -309,6 +314,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
@@ -4,17 +4,16 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
ImagesKwargs,
group_images_by_shape,
reorder_images,
)
@@ -77,7 +76,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _resize_for_patching(
self,
image: "torch.Tensor",
image: torch.Tensor,
target_resolution: tuple,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
input_data_format: ChannelDimension,
) -> "torch.Tensor":
) -> torch.Tensor:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _get_image_patches(
self,
image: "torch.Tensor",
image: torch.Tensor,
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
) -> list[torch.Tensor]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _preprocess(
self,
images: list["torch.Tensor"],
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
interpolation: F.InterpolationMode | None,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
@@ -0,0 +1,37 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```
@@ -1,4 +1,6 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]
@@ -0,0 +1,256 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Objective Selection
objective: str = "diffusion" # "diffusion" or "flow_matching"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
self._validate()
def _validate(self):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_ft.shape:
raise ValueError(
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def is_diffusion(self) -> bool:
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,803 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""
import math
from collections import deque
from typing import TYPE_CHECKING
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
if config.is_diffusion:
self.objective = DiffusionObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x
@@ -0,0 +1,105 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
2. Moving the data to the CPU.
Args:
config: The configuration object for the Multi-Task DiT policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+68 -54
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,13 +32,21 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
@@ -191,7 +199,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -202,7 +210,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -221,14 +229,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.language_model, gemma_expert.model]
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -254,10 +262,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -265,7 +273,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -277,15 +285,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -358,7 +366,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -366,7 +374,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -377,13 +385,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
torch_dtype="float32",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -398,10 +406,11 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -413,8 +422,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -424,15 +433,23 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -446,7 +463,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -470,7 +487,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.language_model, self.gemma_expert.model]
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -510,7 +527,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -576,29 +593,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
@@ -760,7 +767,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -834,7 +841,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -908,6 +915,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -997,14 +1005,12 @@ class PI0Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -1012,7 +1018,7 @@ class PI0Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1025,7 +1031,7 @@ class PI0Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1070,7 +1076,7 @@ class PI0Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -1120,6 +1126,14 @@ class PI0Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
+71 -60
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,14 +32,20 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -92,10 +98,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
return dist.sample((bsize,)).to(device)
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.language_model, gemma_expert.model]
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -252,10 +259,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -263,7 +270,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -275,15 +282,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
torch_dtype="float32",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.language_model, self.gemma_expert.model]
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
@@ -23,7 +23,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "physical-intelligence/fast"
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128
@@ -19,13 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _scipy_available, _transformers_available
@@ -38,11 +37,16 @@ else:
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
else:
CONFIG_MAPPING = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -121,7 +125,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -132,7 +136,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -206,16 +210,22 @@ class PI0FastPaliGemma(nn.Module):
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
self.paligemma.model.language_model = PiGemmaModel(text_config)
self.to_bfloat16_for_selected_params(precision)
@@ -228,10 +238,11 @@ class PI0FastPaliGemma(nn.Module):
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -242,10 +253,18 @@ class PI0FastPaliGemma(nn.Module):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -259,7 +278,7 @@ class PI0FastPaliGemma(nn.Module):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -306,24 +325,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
@@ -332,8 +341,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
# Call the proper gradient_checkpointing_disable() method
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
def _apply_checkpoint(self, func, *args, **kwargs):
@@ -523,7 +532,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -616,7 +625,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -714,7 +723,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Ensure correct precision (bfloat16/float32)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -897,14 +906,12 @@ class PI0FastPolicy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -912,7 +919,7 @@ class PI0FastPolicy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -925,8 +932,9 @@ class PI0FastPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -936,8 +944,6 @@ class PI0FastPolicy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -971,7 +977,7 @@ class PI0FastPolicy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -23,7 +23,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
+363
View File
@@ -0,0 +1,363 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch import nn
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
)
else:
GemmaAttention = None
GemmaConfig = None
GemmaForCausalLM = None
GemmaMLP = None
GemmaModel = None
PaliGemmaModel = None
PaliGemmaForConditionalGeneration = None
DynamicCache = None
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
def _gated_residual(
x: torch.Tensor | None,
y: torch.Tensor | None,
gate: torch.Tensor | None,
) -> torch.Tensor | None:
"""Gated residual: x + y when gate is None, else x + y * gate."""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def layernorm_forward(
layernorm: nn.Module,
x: torch.Tensor,
cond: torch.Tensor | None = None,
):
"""
call layernorm and return hidden states and gate
if cond is not None, use conditional norm
otherwise, use normal gemma norm
"""
if cond is not None:
return layernorm(x, cond=cond)
else:
return layernorm(x)
class PiGemmaRMSNorm(nn.Module):
"""
Adaptive RMSNorm for PI Gemma (AdaRMS).
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
"""
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
if cond_dim is not None:
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(
self,
x: torch.Tensor,
cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
dtype = x.dtype
normed = self._norm(x)
if cond is None or self.dense is None:
normed = normed * (1.0 + self.weight.float())
return normed.type_as(x), None
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()
return normed.to(dtype), gate.to(dtype)
def extra_repr(self) -> str:
if self.dense is not None:
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
return f"dim={self.dim}, eps={self.eps}"
def _get_pi_gemma_decoder_layer_base():
"""base for PiGemmaDecoderLayer"""
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = (
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
)
self.input_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
self.post_attention_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
hidden_states, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
return hidden_states
return _PiGemmaDecoderLayerBase
class PiGemmaModel(GemmaModel): # type: ignore[misc]
"""
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# if not getattr(config, "use_adarms", False):
# return
cond_dim = getattr(config, "adarms_cond_dim", None)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
import logging
logging.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
"""
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PiGemmaModel(config)
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
def __init__(self, config):
super().__init__(config)
self.language_model = PiGemmaModel(config.text_config)
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
def __init__(self, config):
super().__init__(config)
self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
__all__ = [
"PiGemmaModel",
"PiGemmaForCausalLM",
"PiGemmaRMSNorm",
"_gated_residual",
"layernorm_forward",
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
+1 -2
View File
@@ -19,7 +19,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypedDict, TypeVar
from typing import TypedDict, TypeVar, Unpack
import packaging
import safetensors
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@@ -1,156 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy configuration.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
@dataclass
class RLTokenConfig:
"""Configuration for the RL-token encoder/decoder transformer."""
input_dim: int = 2048
rl_token_dim: int = 2048
num_encoder_layers: int = 2
num_decoder_layers: int = 2
num_heads: int = 8
ff_dim: int = 2048
dropout: float = 0.0
@dataclass
class RLTActorConfig:
"""Configuration for the lightweight RL actor MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
std: float = 0.1
@dataclass
class RLTCriticConfig:
"""Configuration for the RLT critic MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@PreTrainedConfig.register_subclass("rlt")
@dataclass
class RLTConfig(PreTrainedConfig):
"""Configuration for the RLT (RL Token) policy.
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
a lightweight actor-critic head using the RL token as state representation.
The frozen VLA also provides reference action chunks that the actor refines.
"""
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
OBS_STATE: {"min": [0.0], "max": [1.0]},
ACTION: {"min": [0.0], "max": [1.0]},
}
)
# ── Device ──
device: str = "cuda"
storage_device: str = "cpu"
# ── VLA backbone ──
vla_checkpoint: str | None = None
# ── RL-token ──
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
# ── Actor / Critic heads ──
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
# ── Action chunks ──
chunk_size: int = 10
vla_chunk_size: int = 50
# ── Training parameters ──
online_steps: int = 50000
offline_steps: int = 5000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
online_step_before_learning: int = 500
warmup_steps: int = 500
async_prefetch: bool = False
# ── Algorithm hyperparameters ──
utd_ratio: int = 5
policy_update_freq: int = 2
discount: float = 0.99
critic_lr: float = 3e-4
actor_lr: float = 3e-4
rl_token_lr: float = 1e-4
tau: float = 0.005
clip_grad_norm: float = 10.0
num_critics: int = 2
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
chunk_stride: int = 2
vla_finetune_weight: float = 0.0
# ── Distributed ──
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
def __post_init__(self):
super().__post_init__()
def get_optimizer_preset(self):
return None
def get_scheduler_preset(self):
return None
def validate_features(self) -> None:
if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def observation_delta_indices(self) -> list | None:
return None
@property
def action_delta_indices(self) -> list | None:
return None
@property
def reward_delta_indices(self) -> None:
return None
-318
View File
@@ -1,318 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy networks.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
Architecture:
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlt.configuration_rlt import RLTConfig
# ── Building blocks ──────────────────────────────────────────────────
class MLP(nn.Module):
"""Simple feedforward network with ReLU activations."""
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
super().__init__()
layers: list[nn.Module] = []
prev = input_dim
for h in hidden_dims:
layers.append(nn.Linear(prev, h))
layers.append(nn.ReLU())
prev = h
layers.append(nn.Linear(prev, output_dim))
self.net = nn.Sequential(*layers)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)
# ── RL Token Encoder ─────────────────────────────────────────────────
class RLTokenEncoder(nn.Module):
"""Compress VLA token embeddings into a single RL token via a small transformer.
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
through transformer encoder layers, and returns the output at the ``e_rl``
position as the RL token ``z_rl``.
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
"""
def __init__(
self,
input_dim: int,
rl_token_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.rl_token_dim = rl_token_dim
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
if input_dim != rl_token_dim:
self.input_proj = nn.Linear(input_dim, rl_token_dim)
else:
self.input_proj = nn.Identity()
encoder_layer = nn.TransformerEncoderLayer(
d_model=rl_token_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def forward(self, z_vla: Tensor) -> Tensor:
"""
Args:
z_vla: VLA token embeddings, shape ``(B, M, D)``.
Returns:
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
"""
batch_size = z_vla.shape[0]
e_rl = self.e_rl.expand(batch_size, -1, -1)
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
seq = self.input_proj(seq)
out = self.transformer(seq)
z_rl = out[:, -1, :] # output at e_rl position
return z_rl
# ── RL Token Decoder ─────────────────────────────────────────────────
class RLTokenDecoder(nn.Module):
"""Autoregressively reconstruct VLA embeddings from z_rl.
Used only during Stage 1 (offline RL-token training).
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
"""
def __init__(
self,
rl_token_dim: int,
output_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.output_dim = output_dim
if rl_token_dim != output_dim:
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
else:
self.rl_proj = nn.Identity()
decoder_layer = nn.TransformerDecoderLayer(
d_model=output_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
self.output_head = nn.Linear(output_dim, output_dim)
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
"""
Args:
z_rl: RL token, shape ``(B, D_rl)``.
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
Returns:
Reconstructed embeddings, shape ``(B, M, D)``.
"""
seq_len = z_vla_stopped.shape[1]
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
decoded = self.transformer(
tgt=target,
memory=z_rl_proj,
tgt_mask=causal_mask,
)
return self.output_head(decoded) # (B, M, D)
# ── Actor ────────────────────────────────────────────────────────────
class RLTActor(nn.Module):
"""Lightweight actor that refines VLA reference action chunks.
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
The actor is conditioned on both the RL state and the VLA's proposed action
chunk, acting as a "VLA-guided action editor".
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
super().__init__()
input_dim = state_dim + action_chunk_dim
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
self.log_std = math.log(std)
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
"""Return the mean action chunk.
Args:
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
Returns:
Refined action chunk (mean), shape ``(B, C*d)``.
"""
x = torch.cat([state, ref_action_chunk], dim=-1)
return self.net(x)
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
"""Sample an action and return (action, log_prob)."""
mean = self.forward(state, ref_action_chunk)
std = math.exp(self.log_std)
noise = torch.randn_like(mean) * std
action = mean + noise
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
std * math.sqrt(2 * math.pi)
)
return action, log_prob
# ── Policy (inference bundle) ────────────────────────────────────────
class RLTPolicy(PreTrainedPolicy):
"""RLT policy — bundles the RL-token encoder and actor for inference.
The frozen VLA backbone is **not** part of this module; it is loaded
separately and its embeddings / reference actions are passed in via the
observation dict (populated by the actor process or a preprocessor).
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
and optimizers. This class only contains what is needed for ``select_action``.
"""
name = "rlt"
config_class = RLTConfig
def __init__(self, config: RLTConfig, dataset_stats=None):
super().__init__(config, dataset_stats)
action_dim = config.output_features["action"].shape[0]
action_chunk_dim = config.chunk_size * action_dim
prop_feature = config.input_features.get("observation.state", None)
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
state_dim = config.rl_token.rl_token_dim + proprioception_dim
# RL-token encoder (frozen after Stage 1)
self.rl_token_encoder = RLTokenEncoder(
input_dim=config.rl_token.input_dim,
rl_token_dim=config.rl_token.rl_token_dim,
num_layers=config.rl_token.num_encoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# RL-token decoder (used only during Stage 1 training)
self.rl_token_decoder = RLTokenDecoder(
rl_token_dim=config.rl_token.rl_token_dim,
output_dim=config.rl_token.input_dim,
num_layers=config.rl_token.num_decoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# Actor MLP
self.actor = RLTActor(
state_dim=state_dim,
action_chunk_dim=action_chunk_dim,
hidden_dims=config.actor.hidden_dims,
std=config.actor.std,
)
self._action_dim = action_dim
self._action_chunk_dim = action_chunk_dim
self._state_dim = state_dim
self._proprioception_dim = proprioception_dim
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a refined action chunk given an observation.
Expects the observation dict to contain:
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
Returns:
Action chunk tensor of shape ``(C*d,)``.
"""
self.eval()
vla_emb = batch["observation.vla_embeddings"]
if vla_emb.dim() == 2:
vla_emb = vla_emb.unsqueeze(0)
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
parts = [z_rl]
if "observation.state" in batch and self._proprioception_dim > 0:
prop = batch["observation.state"]
if prop.dim() == 1:
prop = prop.unsqueeze(0)
parts.append(prop)
state = torch.cat(parts, dim=-1)
ref = batch["observation.reference_action"]
if ref.dim() == 1:
ref = ref.unsqueeze(0)
action = self.actor(state, ref)
return action.squeeze(0)
def reset(self):
pass
+373 -23
View File
@@ -15,11 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -47,13 +52,20 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self.encoder = SACObservationEncoder(config)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_discrete_critic()
self._init_temperature()
def get_optim_params(self) -> dict:
optim_params = {
"actor": [self.actor.parameters()],
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
@@ -71,9 +83,10 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
observations_features = None
if self.encoder.has_images:
observations_features = self.encoder.get_cached_image_features(batch)
if self.shared_encoder and self.actor.encoder.has_images:
observations_features = self.actor.encoder.get_cached_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
@@ -84,35 +97,372 @@ class SACPolicy(
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Actor forward pass."""
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
"""Compute the loss for the given model
def _init_actor(self, continuous_action_dim: int) -> None:
self.actor = Policy(
encoder=self.encoder,
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=False,
**asdict(self.config.policy_kwargs),
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
def _init_discrete_critic(self) -> None:
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder,
input_dim=self.encoder.output_dim,
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=self.shared_encoder,
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
class SACObservationEncoder(nn.Module):
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
+1 -3
View File
@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:
@@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig):
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
load_vlm_weights: bool = False # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
def __post_init__(self):
super().__post_init__()
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
from typing import TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
@@ -593,6 +592,12 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map=device,
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Tokenizer settings
action_tokenizer_path: str | None = "physical-intelligence/fast"
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"
+12 -2
View File
@@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and optional LoRA fine-tuning support.
"""
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
def init_weights(self):
if getattr(self.model, "language_model", None) is not None:
return
super().init_weights()
@classmethod
def from_pretrained(
cls,
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
processor.action_processor = action_tokenizer
else:
action_tokenizer = None
# add pad_token_id to config
config.pad_token_id = processor.tokenizer.pad_token_id
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
# Initialize model with configuration and processor
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(PretrainedConfig):
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
@@ -31,6 +30,15 @@ from transformers.utils import (
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
# always return False (which is the correct behavior when no sliding window cache is in use).
class _SlidingWindowCachePlaceholder:
pass
SlidingWindowCache = _SlidingWindowCachePlaceholder
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
"""
compute default rope parameters for Qwen2_5_VL
"""
base = config.text_config.rope_parameters["rope_theta"]
dim = config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, 1.0
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
self.rope_type = config.rope_parameters.get("rope_type", "default")
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
if self.rope_type == "default":
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
self.rope_kwargs = {}
else:
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
self.rope_kwargs = {}
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+2 -2
View File
@@ -144,7 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -152,7 +152,7 @@ def preprocesser_call(
# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}
@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
)
# ensure backward compatibility for BART CNN models
if not hasattr(self, "forced_bos_token_id"):
self.forced_bos_token_id = None
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
def __init__(self, config: Florence2LanguageConfig):
super().__init__(config)
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"model.encoder.embed_tokens.weight": "model.shared.weight",
"model.decoder.embed_tokens.weight": "model.shared.weight",
}
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: Florence2LanguageConfig):
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
_tied_weights_keys = [
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
"language_model.lm_head.weight",
]
_tied_weights_keys = {
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
}
def __init__(self, config: Florence2Config):
super().__init__(config)
+5 -5
View File
@@ -17,7 +17,7 @@
from __future__ import annotations
from enum import Enum
from typing import Any, TypeAlias, TypedDict
from typing import Any, TypedDict
import numpy as np
import torch
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
COMPLEMENTARY_DATA = "complementary_data"
PolicyAction: TypeAlias = torch.Tensor
RobotAction: TypeAlias = dict[str, Any]
EnvAction: TypeAlias = np.ndarray
RobotObservation: TypeAlias = dict[str, Any]
PolicyAction = torch.Tensor
RobotAction = dict[str, Any]
EnvAction = np.ndarray
RobotObservation = dict[str, Any]
EnvTransition = TypedDict(
@@ -131,15 +131,6 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -158,7 +149,6 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -208,7 +198,6 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -219,7 +208,6 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
+4 -4
View File
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
from typing import Any, TypedDict, TypeVar, cast
import torch
from huggingface_hub import hf_hub_download
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
@dataclass
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
class DataProcessorPipeline[TInput, TOutput](HubMixin):
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
This class chains together multiple `ProcessorStep` instances to form a complete
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
# Type aliases for semantic clarity.
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
class ObservationProcessorStep(ProcessorStep, ABC):
+1 -1
View File
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
+19 -9
View File
@@ -61,7 +61,7 @@ from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.processor import TransitionKey
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue
@@ -248,16 +248,16 @@ def act_with_policy(
logging.info("make_policy")
policy = make_policy(
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
# TODO: Re-enable processor pipeline once refactoring is validated against main
# preprocessor, postprocessor = None, None
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
@@ -288,6 +288,7 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
@@ -648,12 +649,12 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
"""Load the latest policy weights from the learner."""
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dicts = bytes_to_state_dict(bytes_state_dict)
# TODO: check encoder parameter synchronization possible issues:
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
# instead of the updated encoder params from critic (which is optimized separately)
@@ -663,9 +664,18 @@ def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue,
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
state_dicts = move_state_dict_to_device(state_dicts, device=device)
policy.load_state_dict(state_dicts)
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
# Utilities functions
-70
View File
@@ -1,70 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.rl.algorithms.base import (
RLAlgorithm,
RLAlgorithmConfig,
TrainingStats,
)
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
def make_algorithm(
policy: torch.nn.Module,
policy_cfg,
*,
algorithm_name: str,
) -> RLAlgorithm:
"""Construct an :class:`RLAlgorithm` from a policy and its config.
Algorithm selection is explicit via ``algorithm_name`` (from
``cfg.algorithm``).
This is fully registry-driven adding a new algorithm only requires
registering an ``RLAlgorithmConfig`` subclass; no changes here.
The returned algorithm has **no optimizers** yet. On the learner side,
call ``algorithm.make_optimizers()`` afterwards to create them. On the
actor side (inference-only), leave them empty.
Args:
policy: Instantiated policy (e.g. ``SACPolicy``).
policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters
expected by the algorithm config's ``from_policy_config`` class-method.
algorithm_name: Algorithm registry key to instantiate.
"""
known = RLAlgorithmConfig.get_known_choices()
if algorithm_name not in known:
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
algo_config = config_cls.from_policy_config(policy_cfg)
return algo_config.build_algorithm(policy)
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"SACAlgorithm",
"SACAlgorithmConfig",
"RLTAlgorithm",
"RLTAlgorithmConfig",
"make_algorithm",
]
-183
View File
@@ -1,183 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for RL algorithms.
Defines the abstract interface that every algorithm must implement, a registry
for algorithm configs, and a dataclass for training statistics.
"""
from __future__ import annotations
import abc
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import draccus
import torch
from torch import Tensor
from torch.optim import Optimizer
if TYPE_CHECKING:
from lerobot.rl.data_sources.data_mixer import DataMixer
BatchType = dict[str, Any]
@dataclass
class TrainingStats:
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
# Generic containers for all algorithms
losses: dict[str, float] = field(default_factory=dict)
grad_norms: dict[str, float] = field(default_factory=dict)
extra: dict[str, float] = field(default_factory=dict)
def to_log_dict(self) -> dict[str, float]:
"""Flatten all stats into a single dict for logging."""
d: dict[str, float] = {}
for name, val in self.losses.items():
d[name] = val
for name, val in self.grad_norms.items():
d[f"{name}_grad_norm"] = val
for name, val in self.extra.items():
d[name] = val
return d
@dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry):
"""Registry for algorithm configs."""
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
"""Construct the :class:`RLAlgorithm` for this config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
@classmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
"""Build an algorithm config from a policy config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
class RLAlgorithm(abc.ABC):
"""Base for all RL algorithms."""
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One complete training step.
The algorithm calls ``next(batch_iterator)`` as many times as it
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
...
def supports_offline_phase(self) -> bool:
"""Whether this algorithm has an offline pretraining phase.
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
return ``True`` here. The learner checks this before the main online
loop and routes to :meth:`offline_update` accordingly.
"""
return False
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One offline training step (called before any online collection).
Only called when :meth:`supports_offline_phase` returns ``True``.
Uses the same iterator protocol as :meth:`update`.
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement offline_update(). "
"Either override this method or return False from supports_offline_phase()."
)
def transition_to_online(self) -> None: # noqa: B027
"""Called once when switching from offline to online phase.
Use this to freeze modules trained offline, rebuild optimizers for the
online phase, reset step counters, etc.
Default is a no-op; subclasses override when they have an offline phase.
"""
def configure_data_iterator(
self,
data_mixer: DataMixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
) -> Iterator[BatchType]:
"""Create the data iterator this algorithm needs.
The default implementation uses the standard ``data_mixer.get_iterator()``.
Algorithms that need specialised sampling should override this method.
"""
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create, store, and return the optimizers needed for training.
Called on the **learner** side after construction. Subclasses must
override this with algorithm-specific optimizer setup.
"""
return {}
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
return {}
@property
def optimization_step(self) -> int:
"""Current learner optimization step.
Part of the stable contract for checkpoint/resume. Algorithms can
either use this default storage or override for custom behavior.
"""
return getattr(self, "_optimization_step", 0)
@optimization_step.setter
def optimization_step(self, value: int) -> None:
self._optimization_step = int(value)
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors."""
return {}
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner (inverse of ``get_weights``)."""
@torch.no_grad()
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
"""Pre-compute observation features (e.g. frozen encoder cache).
Returns ``(None, None)`` when caching is not applicable.
"""
return None, None
@@ -1,83 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT algorithm configuration."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from lerobot.rl.algorithms.base import RLAlgorithmConfig
if TYPE_CHECKING:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
@RLAlgorithmConfig.register_subclass("rlt")
@dataclass
class RLTAlgorithmConfig(RLAlgorithmConfig):
"""RLT-specific hyper-parameters that control the update loop."""
# ── Action chunks ──
chunk_size: int = 10
chunk_stride: int = 2
# ── Update cadence ──
utd_ratio: int = 5
policy_update_freq: int = 2
clip_grad_norm: float = 10.0
# ── Learning rates ──
actor_lr: float = 3e-4
critic_lr: float = 3e-4
rl_token_lr: float = 1e-4
# ── TD learning ──
discount: float = 0.99
tau: float = 0.005
num_critics: int = 2
# ── Policy constraint (paper Eq. 5) ──
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
# ── Offline RL-token training ──
vla_finetune_weight: float = 0.0
@classmethod
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
return cls(
chunk_size=policy_cfg.chunk_size,
chunk_stride=policy_cfg.chunk_stride,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.clip_grad_norm,
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
rl_token_lr=policy_cfg.rl_token_lr,
discount=policy_cfg.discount,
tau=policy_cfg.tau,
num_critics=policy_cfg.num_critics,
bc_reg_coeff=policy_cfg.bc_reg_coeff,
ref_dropout=policy_cfg.ref_dropout,
vla_finetune_weight=policy_cfg.vla_finetune_weight,
)
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
return RLTAlgorithm(policy=policy, config=self)
@@ -1,319 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) algorithm.
Implements the two-stage training from "RL Token: Bootstrapping Online RL
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
reference-action pass-through, and reference-action dropout.
"""
from __future__ import annotations
import copy
from collections.abc import Iterator
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
from lerobot.utils.constants import ACTION
class RLTCritic(nn.Module):
"""Q-function over (state, action_chunk) pairs.
Paper Eq. 3: Q_psi(x, a_{1:C})
Training-only component lives on the algorithm side, not in the policy.
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
super().__init__()
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
x = torch.cat([state, action_chunk], dim=-1)
return self.net(x)
class RLTAlgorithm(RLAlgorithm):
"""RL Token: lightweight actor-critic on frozen VLA features.
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
ensemble, and target networks. All VLA-specific logic (embedding
extraction, reference actions) lives in ``_prepare_forward_batch``.
"""
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
self.policy = policy
self.config = config
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._device = get_device_from_parameters(self.policy)
self._is_online = False
self._init_critics()
self._move_to_device()
# ── Initialization ───────────────────────────────────────────────
def _init_critics(self) -> None:
state_dim = self.policy._state_dim
action_chunk_dim = self.policy._action_chunk_dim
hidden_dims = self.policy.config.critic.hidden_dims
self.critics = torch.nn.ModuleList(
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
)
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
for ct in self.critic_targets:
ct.requires_grad_(False)
def _move_to_device(self) -> None:
self.critics.to(self._device)
self.critic_targets.to(self._device)
# ── Offline phase (Stage 1): RL-token training ───────────────────
def supports_offline_phase(self) -> bool:
return True
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Train RL-token encoder/decoder on demonstration data.
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
"""
batch = next(batch_iterator)
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
z_rl = self.policy.rl_token_encoder(z_vla)
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
loss_ro = F.mse_loss(z_reconstructed, z_vla)
self.optimizers["rl_token"].zero_grad()
loss_ro.backward()
torch.nn.utils.clip_grad_norm_(
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
max_norm=self.config.clip_grad_norm,
)
self.optimizers["rl_token"].step()
self._optimization_step += 1
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
def transition_to_online(self) -> None:
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
self.policy.rl_token_encoder.requires_grad_(False)
self.policy.rl_token_decoder.requires_grad_(False)
self._is_online = True
self.optimizers = {
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
self._optimization_step = 0
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One full RLT update step with UTD critic warm-up.
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
the last batch also updates the actor (every ``policy_update_freq`` steps).
"""
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
self._critic_step(fb)
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
critic_loss = self._critic_step(fb)
stats = TrainingStats(losses={"loss_critic": critic_loss})
if self._optimization_step % self.config.policy_update_freq == 0:
actor_loss, bc_loss, q_val = self._actor_step(fb)
stats.losses["loss_actor"] = actor_loss
stats.extra["bc_loss"] = bc_loss
stats.extra["q_value_mean"] = q_val
self._update_target_networks()
self._optimization_step += 1
return stats
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
"""Convert a replay batch into algorithm-ready tensors.
Extracts RL-token from VLA embeddings, builds RL state, reads
reference action from complementary_info.
"""
obs = batch["state"]
next_obs = batch["next_state"]
device = self._device
vla_emb = obs["observation.vla_embeddings"].to(device)
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
with torch.no_grad():
z_rl = self.policy.rl_token_encoder(vla_emb)
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
parts = [z_rl]
next_parts = [z_rl_next]
if "observation.state" in obs and self.policy._proprioception_dim > 0:
prop = obs["observation.state"].to(device)
next_prop = next_obs["observation.state"].to(device)
parts.append(prop)
next_parts.append(next_prop)
state = torch.cat(parts, dim=-1)
next_state = torch.cat(next_parts, dim=-1)
action = batch[ACTION].to(device)
reward = batch["reward"].to(device)
done = batch["done"].to(device)
ref_action = None
comp_info = batch.get("complementary_info")
if comp_info is not None and "reference_action" in comp_info:
ref_action = comp_info["reference_action"].to(device)
return {
"state": state,
"next_state": next_state,
"action": action,
"reward": reward,
"done": done,
"reference_action": ref_action,
}
def _critic_step(self, fb: dict[str, Any]) -> float:
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
state = fb["state"]
next_state = fb["next_state"]
action = fb["action"]
reward = fb["reward"]
done = fb["done"]
with torch.no_grad():
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros_like(action)
next_action = self.policy.actor(next_state, ref)
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
discount_chunk = self.config.discount**self.config.chunk_size
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
q_preds = [c(state, action) for c in self.critics]
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
self.optimizers["critic"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["critic"].step()
return loss.item()
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
"""Paper Eq. 5: maximize Q while staying near VLA reference.
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
With reference-action dropout applied to the actor's ref input.
"""
state = fb["state"]
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
# Reference-action dropout (paper Section IV-B)
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
ref_input = ref * mask
action = self.policy.actor(state, ref_input)
q_value = self.critics[0](state, action)
bc_loss = F.mse_loss(action, ref)
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
self.optimizers["actor"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["actor"].step()
return loss.item(), bc_loss.item(), q_value.mean().item()
def _update_target_networks(self) -> None:
tau = self.config.tau
for critic, target in zip(self.critics, self.critic_targets, strict=True):
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
# ── Optimizer management ─────────────────────────────────────────
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create optimizers. Initially for RL-token (Stage 1)."""
self.optimizers = {
"rl_token": torch.optim.Adam(
list(self.policy.rl_token_encoder.parameters())
+ list(self.policy.rl_token_decoder.parameters()),
lr=self.config.rl_token_lr,
),
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
# ── Weight sync ──────────────────────────────────────────────────
def get_weights(self) -> dict[str, Any]:
"""Push actor + RL-token encoder to actors (small footprint)."""
weights = {
"actor": self.policy.actor.state_dict(),
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
}
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
if "actor" in weights:
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
if "rl_token_encoder" in weights:
self.policy.rl_token_encoder.load_state_dict(
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
)
@@ -1,81 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAC algorithm configuration."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
from lerobot.rl.algorithms.base import RLAlgorithmConfig
if TYPE_CHECKING:
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC-specific hyper-parameters that control the update loop."""
utd_ratio: int = 1
policy_update_freq: int = 1
clip_grad_norm: float = 40.0
actor_lr: float = 3e-4
critic_lr: float = 3e-4
temperature_lr: float = 3e-4
discount: float = 0.99
temperature_init: float = 1.0
target_entropy: float | None = None
use_backup_entropy: bool = True
critic_target_update_weight: float = 0.005
num_critics: int = 2
num_subsample_critics: int | None = None
num_discrete_actions: int | None = None
shared_encoder: bool = True
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
use_torch_compile: bool = True
@classmethod
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
return cls(
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.grad_clip_norm,
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
temperature_lr=policy_cfg.temperature_lr,
discount=policy_cfg.discount,
temperature_init=policy_cfg.temperature_init,
target_entropy=policy_cfg.target_entropy,
use_backup_entropy=policy_cfg.use_backup_entropy,
critic_target_update_weight=policy_cfg.critic_target_update_weight,
num_critics=policy_cfg.num_critics,
num_subsample_critics=policy_cfg.num_subsample_critics,
num_discrete_actions=policy_cfg.num_discrete_actions,
shared_encoder=policy_cfg.shared_encoder,
critic_network_kwargs=policy_cfg.critic_network_kwargs,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
use_torch_compile=policy_cfg.use_torch_compile,
)
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
return SACAlgorithm(policy=policy, config=self)
@@ -1,409 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAC (Soft Actor-Critic) algorithm.
This module encapsulates all SAC-specific training logic (critic, actor,
temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface.
"""
from __future__ import annotations
import math
from collections.abc import Iterator
from dataclasses import asdict
from typing import Any
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.sac.modeling_sac import (
DISCRETE_DIMENSION_INDEX,
CriticEnsemble,
CriticHead,
DiscreteCritic,
SACObservationEncoder,
SACPolicy,
)
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import move_state_dict_to_device
class SACAlgorithm(RLAlgorithm):
"""Soft Actor-Critic with optional discrete-critic head.
Owns the ``SACPolicy`` and its optimizers. All loss methods call
``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor``
directly, so any policy that returns ``{"action", "log_prob"}`` from its
``forward()`` is compatible.
"""
def __init__(
self,
policy: SACPolicy,
config: SACAlgorithmConfig,
):
self.policy = policy
self.config = config
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._device = get_device_from_parameters(self.policy)
self._init_critic_encoder()
self._init_critics()
self._init_temperature()
self._move_to_device()
def _init_critic_encoder(self) -> None:
"""Build or share the encoder used by critics."""
if self.config.shared_encoder:
self.critic_encoder = self.policy.encoder
self.policy.actor.encoder_is_shared = True
else:
self.critic_encoder = SACObservationEncoder(self.policy.config)
def _init_critics(self) -> None:
"""Build critic ensemble, targets, and optional discrete critic."""
action_dim = self.policy.config.output_features[ACTION].shape[0]
input_dim = self.critic_encoder.output_dim + action_dim
heads = [
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads)
target_heads = [
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critic_target()
def _init_discrete_critic_target(self) -> None:
"""Build only the target discrete critic."""
input_dim = self.critic_encoder.output_dim
self.discrete_critic_target = DiscreteCritic(
encoder=self.critic_encoder,
input_dim=input_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (kmeftah) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha) and default target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
action_dim = self.policy.config.output_features[ACTION].shape[0]
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _move_to_device(self) -> None:
"""Move algorithm-owned modules to the policy device."""
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if hasattr(self, "discrete_critic_target"):
self.discrete_critic_target.to(self._device)
@property
def temperature(self) -> float:
return self.log_alpha.exp().item()
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Run one full SAC update with UTD critic warm-up.
Pulls ``utd_ratio`` batches from ``batch_iterator``. The first
``utd_ratio - 1`` batches are used for critic-only warm-up steps;
the last batch drives the full update (critic + actor + temperature).
"""
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
forward_batch = self._prepare_forward_batch(batch)
loss_critic = self._compute_loss_critic(forward_batch)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
torch.nn.utils.clip_grad_norm_(
self.critic_ensemble.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["critic"].step()
if self.config.num_discrete_actions is not None:
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
self.optimizers["discrete_critic"].zero_grad()
loss_discrete.backward()
torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["discrete_critic"].step()
self._update_target_networks()
batch = next(batch_iterator)
forward_batch = self._prepare_forward_batch(batch)
loss_critic = self._compute_loss_critic(forward_batch)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
self.critic_ensemble.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["critic"].step()
critic_loss_val = loss_critic.item()
stats = TrainingStats(
losses={"loss_critic": critic_loss_val},
grad_norms={"critic": critic_grad_norm},
)
if self.config.num_discrete_actions is not None:
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
self.optimizers["discrete_critic"].zero_grad()
loss_discrete.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_discrete.item()
stats.grad_norms["discrete_critic"] = dc_grad
if self._optimization_step % self.config.policy_update_freq == 0:
for _ in range(self.config.policy_update_freq):
actor_loss = self._compute_loss_actor(forward_batch)
self.optimizers["actor"].zero_grad()
actor_loss.backward()
actor_grad = torch.nn.utils.clip_grad_norm_(
self.policy.actor.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["actor"].step()
temp_loss = self._compute_loss_temperature(forward_batch)
self.optimizers["temperature"].zero_grad()
temp_loss.backward()
temp_grad = torch.nn.utils.clip_grad_norm_(
[self.log_alpha],
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["temperature"].step()
stats.losses["loss_actor"] = actor_loss.item()
stats.losses["loss_temperature"] = temp_loss.item()
stats.grad_norms["actor"] = actor_grad
stats.grad_norms["temperature"] = temp_grad
stats.extra["temperature"] = self.temperature
self._update_target_networks()
self._optimization_step += 1
return stats
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
obs_features = batch.get("observation_feature")
next_obs_features = batch.get("next_observation_feature")
with torch.no_grad():
next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features})
next_actions = next_output["action"]
next_log_probs = next_output["log_prob"]
q_targets = self.critic_target(next_observations, next_actions, next_obs_features)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
min_q, _ = q_targets.min(dim=0)
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
if self.config.num_discrete_actions is not None:
actions = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_ensemble(observations, actions, obs_features)
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum()
return critics_loss
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
obs_features = batch.get("observation_feature")
next_obs_features = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete).long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties = complementary_info.get("discrete_penalty")
with torch.no_grad():
next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features)
best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
target_next_qs = self.discrete_critic_target(next_observations, next_obs_features)
target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1)
rewards_disc = rewards
if discrete_penalties is not None:
rewards_disc = rewards + discrete_penalties
target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q
predicted_qs = self.policy.discrete_critic(observations, obs_features)
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
return F.mse_loss(input=predicted_q, target=target_q)
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
obs_features = batch.get("observation_feature")
output = self.policy({"state": observations, "observation_feature": obs_features})
actions_pi = output["action"]
log_probs = output["log_prob"]
q_preds = self.critic_ensemble(observations, actions_pi, obs_features)
min_q = q_preds.min(dim=0)[0]
return ((self.temperature * log_probs) - min_q).mean()
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
obs_features = batch.get("observation_feature")
with torch.no_grad():
output = self.policy({"state": observations, "observation_feature": obs_features})
log_probs = output["log_prob"]
return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
def _update_target_networks(self) -> None:
tau = self.config.critic_target_update_weight
for target_p, p in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
):
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
if self.config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
"""Build the dict expected by loss computation from a sampled batch."""
observations = batch["state"]
next_observations = batch["next_state"]
observation_features, next_observation_features = self.get_observation_features(
observations, next_observations
)
forward_batch: dict[str, Any] = {
ACTION: batch[ACTION],
"reward": batch["reward"],
"state": observations,
"next_state": next_observations,
"done": batch["done"],
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
if "complementary_info" in batch:
forward_batch["complementary_info"] = batch["complementary_info"]
return forward_batch
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create Adam optimizers for the SAC components and store them."""
actor_params = [
p
for n, p in self.policy.actor.named_parameters()
if not self.config.shared_encoder or not n.startswith("encoder")
]
self.optimizers = {
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
}
if self.config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors (includes actor + discrete critic)."""
return move_state_dict_to_device(self.policy.state_dict(), device="cpu")
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
state = move_state_dict_to_device(weights, device=device)
self.policy.load_state_dict(state)
@torch.no_grad()
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
if not self.config.shared_encoder:
return None, None
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
return None, None
if not self.policy.encoder.has_images:
return None, None
observation_features = self.policy.encoder.get_cached_image_features(observations)
next_observation_features = self.policy.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
-17
View File
@@ -1,17 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
-94
View File
@@ -1,94 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from typing import Any
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
BatchType = dict[str, Any]
class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies.
Subclasses must implement ``sample(batch_size)`` and may override
``get_iterator`` for specialised iteration.
"""
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
...
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Infinite iterator that yields batches.
The default implementation repeatedly calls ``self.sample()``.
Subclasses with underlying buffer iterators (async prefetch)
should override this for better throughput.
"""
while True:
yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an optional offline replay buffer.
When both buffers are present, each batch is constructed by sampling
``ceil(batch_size * online_ratio)`` from the online buffer and the
remainder from the offline buffer, then concatenating.
This mixer assumes both online and offline buffers are present.
"""
def __init__(
self,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer | None = None,
online_ratio: float = 1.0,
):
if not 0.0 <= online_ratio <= 1.0:
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
self.online_buffer = online_buffer
self.offline_buffer = offline_buffer
self.online_ratio = online_ratio
def sample(self, batch_size: int) -> BatchType:
if self.offline_buffer is None:
return self.online_buffer.sample(batch_size)
n_online = max(1, int(batch_size * self.online_ratio))
n_offline = batch_size - n_online
online_batch = self.online_buffer.sample(n_online)
offline_batch = self.offline_buffer.sample(n_offline)
return concatenate_batch_transitions(online_batch, offline_batch)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches from online/offline mixed sampling."""
while True:
yield self.sample(batch_size)
+284 -92
View File
@@ -65,11 +65,9 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.rl.algorithms import make_algorithm
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.trainer import RLTrainer
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -95,7 +93,7 @@ from lerobot.utils.train_utils import (
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.transition import move_transition_to_device
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
@@ -266,8 +264,8 @@ def add_actor_information_and_train(
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``).
- Periodically pushes updated weights to actors.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
@@ -286,15 +284,17 @@ def add_actor_information_and_train(
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.async_prefetch
queue_size = cfg.queue_size
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -306,7 +306,7 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
policy = make_policy(
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
@@ -315,24 +315,19 @@ def add_actor_information_and_train(
policy.train()
algorithm = make_algorithm(
policy=policy,
policy_cfg=cfg.policy,
algorithm_name=cfg.algorithm,
)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
# TODO: Re-enable processor pipeline once refactoring is validated against main
preprocessor, postprocessor = None, None
# Push initial policy weights to actors (same path as periodic push)
state_bytes = state_to_bytes(algorithm.get_weights())
parameters_queue.put(state_bytes)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
total_batch_size = cfg.batch_size
batch_size = cfg.batch_size
offline_replay_buffer = None
if cfg.dataset is not None:
@@ -341,70 +336,20 @@ def add_actor_information_and_train(
device=device,
storage_device=storage_device,
)
# DataMixer: online-only or online/offline 50-50 mix
data_mixer = OnlineOfflineMixer(
online_buffer=replay_buffer,
offline_buffer=offline_replay_buffer,
online_ratio=cfg.online_ratio,
)
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=total_batch_size,
preprocessor=preprocessor,
action_dim=cfg.policy.output_features["action"].shape[0],
async_prefetch=async_prefetch,
queue_size=queue_size,
)
# If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
logging.info("Starting learner thread")
interaction_message = None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
algorithm.optimization_step = optimization_step
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
offline_steps = getattr(cfg.policy, "offline_steps", 0)
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
offline_mixer = OnlineOfflineMixer(
online_buffer=offline_replay_buffer,
offline_buffer=None,
online_ratio=1.0,
)
offline_iterator = algorithm.configure_data_iterator(
data_mixer=offline_mixer,
batch_size=total_batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
for step in range(offline_steps):
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
return
stats = algorithm.offline_update(offline_iterator)
if step % log_freq == 0:
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
if wandb_logger:
log_dict = stats.to_log_dict()
log_dict["offline_step"] = step
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
algorithm.transition_to_online()
optimizers = algorithm.get_optimizers()
logging.info("[LEARNER] Offline phase complete, transitioned to online")
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
@@ -435,22 +380,180 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
time_for_one_optimization_step = time.time()
if online_iterator is None:
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
stats = trainer.training_step()
if offline_replay_buffer is not None and offline_iterator is None:
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["discrete_critic"].step()
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
critic_output = policy.forward(forward_batch, model="critic")
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["discrete_critic"].step()
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
state_dicts = algorithm.get_weights()
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
training_infos = stats.to_log_dict()
# Update target networks (main and discrete)
policy.update_target_networks()
# Log training metrics at specified intervals
optimization_step = algorithm.optimization_step
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
@@ -478,6 +581,7 @@ def add_actor_information_and_train(
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
@@ -494,8 +598,6 @@ def add_actor_information_and_train(
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
@@ -580,8 +682,6 @@ def save_training_checkpoint(
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
preprocessor=None,
postprocessor=None,
) -> None:
"""
Save training checkpoint and associated data.
@@ -605,8 +705,6 @@ def save_training_checkpoint(
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
preprocessor: Optional preprocessor pipeline to save
postprocessor: Optional postprocessor pipeline to save
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
@@ -623,8 +721,6 @@ def save_training_checkpoint(
policy=policy,
optimizer=optimizers,
scheduler=None,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save interaction step manually
@@ -662,6 +758,58 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler
# Training setup functions
@@ -866,6 +1014,33 @@ def initialize_offline_replay_buffer(
# Utilities/Helpers functions
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.learner == "threads"
@@ -916,6 +1091,23 @@ def check_nan_in_transition(
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):

Some files were not shown because too many files have changed in this diff Show More