Compare commits

..

86 Commits

Author SHA1 Message Date
jade.choghari@huggingface.co 8e77625cf1 clean 2025-11-19 16:22:01 +00:00
jade.choghari@huggingface.co 514e3cd83d add convert scripts 2025-11-19 16:11:07 +00:00
Jade Choghari 25ef3c520c fix test, add to train 2025-11-18 23:58:27 +01:00
Jade Choghari 464f14a0f4 fix imports 2025-11-18 18:41:01 +01:00
Jade Choghari 61f8c8a0f2 add docs 2025-11-18 18:37:27 +01:00
Jade Choghari b196f04d48 more changes 2025-11-18 18:29:19 +01:00
Jade Choghari 91768a9879 Merge branch 'main' into refactor/env-libero 2025-11-18 17:43:48 +01:00
Jade Choghari 4fb41d3e5a style 2025-11-18 17:29:25 +01:00
Jade Choghari b257e02ccf add env processor 2025-11-18 17:27:43 +01:00
Michel Aractingi b464d9f8bc Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)
* filter episodes in load_nested_dataset

* nit

* remove test filtering

* move import to module level

* added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
2025-11-18 17:26:41 +01:00
Jade Choghari 8915c6cd25 more changes 2025-11-18 16:28:18 +01:00
Jade Choghari 0ed2f87fba iterate on review: 2025-11-18 16:12:09 +01:00
Jade Choghari a068618faf more styling fixes: 2025-11-18 16:01:41 +01:00
Jade Choghari 769eb27c87 more fixes 2025-11-18 15:54:14 +01:00
Jade Choghari c3b8f65a8c put axis-1 2025-11-18 15:40:23 +01:00
Jade Choghari 33a8d31af0 Merge branch 'main' into refactor/env-libero 2025-11-18 15:33:45 +01:00
Jade Choghari 6c9f169996 clean 2025-11-18 15:26:12 +01:00
jade.choghari@huggingface.co 9979b62c52 more 2025-11-18 14:57:43 +01:00
jade.choghari@huggingface.co e91e48b79c fix style 2025-11-18 14:25:48 +01:00
jade.choghari@huggingface.co b4b5d057b1 more fixes 2025-11-18 14:24:59 +01:00
jade.choghari@huggingface.co 9a115c303c more changes 2025-11-18 14:13:17 +01:00
jade.choghari@huggingface.co e5efb6b6dc working changes 2025-11-18 13:55:45 +01:00
Jade Choghari 4c67330430 more changes 2025-11-18 10:33:23 +01:00
Michel Aractingi 784cdae55a Fixes in port droid scripts (#2455)
* Fixes in port droid scripts

* revert default mem-per-cpu

* style nit

* fix relative imports

* style nit
2025-11-17 23:42:30 +01:00
Steven Palma d9e74a9d37 chore(dependencies): Bump lerobot to 0.4.2 (#2423) 2025-11-12 13:13:57 +01:00
Steven Palma a5b29d4301 chore(installation): remove libero installation patch (#2416)
* chore(installation): remove libero installation patch

* fix(ci): exclude groot for unbound deps test
2025-11-10 11:51:52 +01:00
Steven Palma a4aa316470 fix(dataset): fix data access bottleneck for faster training (#2408) 2025-11-07 21:54:44 +01:00
Michel Aractingi f6b16f6d97 fix(dataset_tools) Critical bug in modify features (#2342)
* fix bug in `_copy_data_with_feature_changes`

* Update src/lerobot/datasets/dataset_tools.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* add missing import

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2025-11-04 15:56:41 +01:00
Jade Choghari df0c335a5a feat(sim): EnvHub - allow loading envs from the hub (#2121)
* add env from the hub support

* add safe loading

* changes

* add tests, docs

* more

* style/cleaning

* order

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-04 14:52:46 +01:00
Jade Choghari 87ed3a2b6e dep(upgrade): add libero as a pypi package (#2365)
* add changes

* Update pyproject.toml

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* add openpi-transformers

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* new changes

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* Update hf-libero version in pyproject.toml

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-04 10:43:52 +01:00
Jade Choghari d57d1aa197 fix(make_policy): rename mapping edge cases in training (#2332)
* fix bug

* update fixes

* add hf license

* more fixes

* add transformers

* iterate on review

* more fixes

* more fixes

* add a False test

* reduce img size

* reduce img size

* skip the test

* add

* add style
2025-10-31 13:08:42 +01:00
Caroline Pascal 3f8c5d9809 fix(video_key typo): fixing video_key typo in update_video_info (#2323) 2025-10-28 09:41:33 +01:00
Steven Palma d1548e1d13 docs(install): imrpove groot and libero installation instructions (#2314) 2025-10-26 15:37:41 +08:00
Steven Palma d11ec6b5ef docs(readme): update installation instructions for 0.4.0 (#2310) 2025-10-24 17:31:37 +02:00
Steven Palma c75455a6de chore(dependecies): Bump lerobot to 0.4.1 (#2299)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-23 20:59:30 +02:00
Steven Palma f25ac02e6c chore(dependencies): Bump lerobot to 0.4.0 (#2298)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-23 20:20:52 +02:00
Steven Palma 23cb668cac fix(ci): add fastapi dep + bump to 0.3.5 (#2301) 2025-10-23 19:53:44 +02:00
Steven Palma 2ea3043b1b patch(ci): remove pi & libero tags from PyPi release temporary due to their reliance on git dependencies (#2300) 2025-10-23 19:37:11 +02:00
Steven Palma 0f61e2415f chore(deps): update requirements file (#2297) 2025-10-23 18:38:41 +02:00
Michel Aractingi 76a425c600 Fix: check_cached_episodes doesn't check if the requested episode video were downloaded (#2296)
* In `check_cached_episodes_sufficient` check whether all the requested video files are downloaded

* optimize loop over the video paths

* revert example num_workers

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* set num_workers to zero in example

* style nit

* reintroduce copilot optim

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-23 17:34:03 +02:00
Lior Ben Horin df71f3ce24 docs(policies): GR00T updates (#2293)
* Update Libero beval results + fix phrasing

* style of GR00T wording
2025-10-23 15:01:41 +02:00
Francesco Capuano 326aca0a48 Add API Examples (#2289)
* (unscrewing things up) (#2288)

* fix: expose a function explicitly building a frame for inference

* fix: first make dataset frame, then make ready for inference

* fix: reducing reliance on lerobot record for policy's ouptuts too

* fix: encapsulating squeezing out + device handling from predict action

* fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion)

* refactor(envs): add custom-observation-size (#2167)

* fix: add MockMotorBus to MockRobot

* rl: first drafts

* add: all components of HIL SERL

* fix: actor block works

* fix: less friction, less friction

* add: hil-serl complete example

* fix: dataset names

* fix: restructuring example folder

* fix: act works but found bug in how ACT works

* fix: same path for both pre and postprocessors

* fix: paths

* add: example usage for act

* add: using ACT example

* fix: training examples

* fix: using examples

* fix: camera index

* fix: rename workflows into tutorial so that the path of the files is lerobot/examples/tutorial/...

* fix: upload everything in one repo

* fix: model name

* fix: simplify model path

* add: VLAs example

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix: minor fix using named attributes

* fix: change model to act

* fix: named attributes for inference frame building

* fix: minor fixes to smolvla

* fix: small changes to pi0

* remove: old file that should have never been committed (ups sorry sorry)

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2025-10-23 14:18:13 +02:00
Steven Palma be46bdea8f feat(policies): add Nvidia Gr00t N1.5 model (#2292)
* feat(policies): add Nvidia Gr00t N1.5 model

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>

* fix(docs): add groot to index

Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>

---------

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>
2025-10-23 13:50:30 +02:00
Steven Palma 306429a85b fix(cameras): opencv camera index casting (#2286) 2025-10-22 17:27:31 +02:00
Michel Aractingi 12f2f35760 - Introduce _current_file_start_frame for better tracking of the number of frames in each parquet file (#2280)
- Added testing for that section in `test_datasets.py`
2025-10-21 16:17:12 +02:00
Jade Choghari a024d33750 fix(bug): Fix policy renaming ValueError during training (#2278)
* fixes

* style

* Update src/lerobot/policies/factory.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* style

* add review fixes

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-21 16:00:46 +02:00
Hakjin Lee 63cd2111ad [Fix] Device Error on SmolVLA Multi-GPU Training (#2270)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-21 14:26:31 +02:00
Steven Palma abe9e79825 chore(dependencies): bump & ceil gymnasium version + pin metaworld version + bump gym-hil (#2267)
* chore(dependencies): bump & ceil gymnasium version + pin metaworld version

Co-authored-by: Jade Choghari <chogharijade@gmail.com>

* chore(dependencies): bump gym-hil to be compatible

---------

Co-authored-by: Jade Choghari <chogharijade@gmail.com>
2025-10-21 12:56:32 +02:00
Steven Palma 503fc4e9f4 fix(ci): exclude motor tests in multi-gpu setup (#2276) 2025-10-21 12:14:26 +02:00
Xiaoxuan Liu 92b479f9ac Fix camera FPS set issue (#2275)
Set camera width/height 1st before FPS setting, to avoid FPS set failure alike:

ERROR:__main__:Failed to connect or configure OpenCV camera /dev/video2: OpenCVCamera(/dev/video2) failed to set fps=30 (actual_fps=25.0).
2025-10-21 11:31:03 +02:00
Steven Palma b954337ac7 fix(scripts): add missing observation overwrite in eval and async (#2265) 2025-10-20 23:34:24 +02:00
Jade Choghari 5f6f476f32 fix: support cuda:0, cuda:1 in string selection (#2256)
* fix

* update func 2

* update nightly

* fix quality

* ignore test_dynamixel
2025-10-20 23:29:05 +02:00
Antoine 502fdc0630 fix dataset revision (#2260)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-20 18:45:09 +02:00
Steven Palma 9db6213895 chore(style): update mypy config (#2257)
* chore(style): update mypy config

* fix(cameras): mypy check
2025-10-20 16:25:03 +02:00
hls aa1d906802 Enhance OpenCVCamera with FOURCC for MJPEG support and validation (#1558)
* Enhance OpenCVCamera with FOURCC support and validation

- Added FOURCC configuration option to OpenCVCamera and OpenCVCameraConfig for specifying video format.
- Implemented _validate_fourcc method to validate and set the camera's FOURCC code.
- Updated _configure_capture_settings to apply FOURCC settings before FPS and resolution.
- Enhanced camera detection to include default FOURCC code in camera info.
- Updated documentation to reflect new FOURCC parameter and its implications on performance.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add tests for FOURCC configuration in OpenCVCamera

- Implemented tests to validate FOURCC configuration and its application in OpenCVCamera.
- Added checks for valid FOURCC codes and ensured that invalid codes raise appropriate errors.
- Included a test for camera connection functionality using specified FOURCC settings.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix circular import in __init__.py - change to relative import

* Update src/lerobot/cameras/opencv/configuration_opencv.py

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: hls <56255627+forgetwhatuwant@users.noreply.github.com>

* Update src/lerobot/cameras/opencv/configuration_opencv.py

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: hls <56255627+forgetwhatuwant@users.noreply.github.com>

* fix(camera_opencv): ensure MSMF hardware transform compatibility on Windows before importing OpenCV

* This change reverts the import from a relative import (.) back to the absolute import (lerobot.) as it was previously

* opencv/config: satisfy Ruff SIM102 by merging nested if for fourcc validation

* style(opencv/config): apply ruff-format changes

---------

Signed-off-by: hls <56255627+forgetwhatuwant@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: forgetwhatuwant <forgetwhatuwant@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-20 14:19:21 +02:00
tetsugo02 eff8a6fd12 Fix typehint and address the mypy errors of src/lerobot/configs (#1746)
* fix: update policy handling and type annotations
added typehint and addressed the error of mypy

* fix: rename should_push_to_hub to push_to_hub
I find that there are other dependencies of push_to_hub so I fix the property name back to original one.

* fix: typo

* fix: changed the position of try-except block
As the copilot said, use raise before `hf_hub_download` would stop program even it is able to download

* fix: update pre-commit configuration and mypy settings
add args: --follow-imports=silent to pass error which have no relationship with src/lerobot/configs

* fix: remove the specific path in .pre-commit-config.yaml

* feat: enhance typehint to adapt mypy strict mode.

* fix: remove duplicate FileNotFoundError check in PreTrainedConfig

* fix: make "pre-commit run --all-files" pass

* fix: replace logging with logger for better logging practices

* fix: fixed extra changes of lint and  format changes

* fix: fixed extra changes out of "configs" module

* Update src/lerobot/configs/policies.py

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com>

* fix: add logging for scratch job

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Signed-off-by: tetsugo02 <131431116+tetsugo02@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-20 12:57:32 +02:00
Jaisree25 c54cd529a2 Fix: camera code changes only (#1788) 2025-10-20 12:57:10 +02:00
Huy a5ca206c49 chore(mypy-compliant): Ensure the model module passes MyPy type checks (#1782)
* feat(mypy-compliant): Ensure the model module passes MyPy type checks

* fix

* uncomment pyproject.toml for model module

* fix

* fix

---------

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-19 23:35:21 +02:00
Bryson Jones 88100943ef add affine transforms and test (#2145)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-19 21:39:30 +02:00
Jade Choghari a95b15ccc0 refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)
* refactor(env): introduce explicit gym ID handling in EnvConfig/factory

This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.

Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more changes

* quality

* fix test

---------

Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-10-19 20:50:00 +02:00
Xingdong Zuo a97d078d95 Feat: Support CLI for Launching LeKiwiHost (#1614)
* Support CLI for LeKiwiHost

Signed-off-by: Xingdong Zuo <zuoxingdong@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xingdong Zuo <zuoxingdong@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-19 20:19:57 +02:00
Steven Palma 98662e5f24 chore(install): use miniforge instead of miniconda (#2249)
Co-authored-by: Silvio Traversaro <silvio@traversaro.it>
2025-10-19 19:19:21 +02:00
Caroline Pascal 4d8f242af9 chore(pyproject): cleaning no longer existing files/folders in pyproject exclude_dirs (#2240) 2025-10-19 14:43:07 +02:00
Francesco Capuano 1ff8986c77 fix: add MockMotorBus to MockRobot (#2081) 2025-10-18 12:06:43 +02:00
Lycoris f0aeded142 Fixes failed to delete images because the timing of gc is uncertain (#1710)
* Prevents resource leak in video_utils when getting width and height

Added the with statement when opening the image to ensure that the file handle is properly closed after its contents are read. 
Otherwise, shutil.rmtree(img_dir) will fail when called after the encode_video_frames function completes.

Signed-off-by: Lycoris <32864669+lycoris1129@users.noreply.github.com>

---------

Signed-off-by: Lycoris <32864669+lycoris1129@users.noreply.github.com>
2025-10-18 06:47:07 +02:00
Steven Palma da5d2f3e91 chore(dependencies): upgrade rerun (#2237)
* chore(dependencies): upgrade rerun

Co-authored-by: Ben Zhang <benzhangniu@gmail.com>

* test(utils): fix rerun scalars

---------

Co-authored-by: Ben Zhang <benzhangniu@gmail.com>
2025-10-18 01:35:02 +02:00
Steven Palma d6ea3bbce0 fix(docs): update example flags for lerobot-dataset-viz (#2238)
Co-authored-by: Yingjie Wei <yingjie.wei@cern.ch>
Co-authored-by: DWarez <ldwarezl@gmail.com>
2025-10-18 01:34:44 +02:00
pre-commit-ci[bot] 7aedbbf81a [pre-commit.ci] pre-commit autoupdate (#1563)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/pre-commit/pre-commit-hooks: v5.0.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v5.0.0...v6.0.0)
- [github.com/astral-sh/ruff-pre-commit: v0.12.4 → v0.13.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.4...v0.13.0)
- [github.com/adhtruong/mirrors-typos: v1.34.0 → v1.36.2](https://github.com/adhtruong/mirrors-typos/compare/v1.34.0...v1.36.2)
- [github.com/gitleaks/gitleaks: v8.27.2 → v8.28.0](https://github.com/gitleaks/gitleaks/compare/v8.27.2...v8.28.0)
- [github.com/woodruffw/zizmor-pre-commit: v1.11.0 → v1.13.0](https://github.com/woodruffw/zizmor-pre-commit/compare/v1.11.0...v1.13.0)

* chore: update pre-commit versions

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-18 01:20:45 +02:00
Steven Palma 1ee8d824f5 fix(docs): update eval example (#2236)
Co-authored-by: Hemanth M <ee24b024@smail.iitm.ac.in>
2025-10-18 00:51:17 +02:00
Maximilian Li f7c4f99545 fix(factory): ensure output and input features are set only if not already defined (#1771)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-18 00:50:34 +02:00
Steven Palma 92b6254473 feat(utils): add support for Intel XPU backend (#2233)
* feat: add support for Intel XPU backend in device selection

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: update is_amp_available to include xpu as a valid device

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* fix: remove unused return and add comments on fp64 fallback handling

* fix(utils): return dtype in case xpu has fp64

---------

Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Lim, Xiang Yang <xiang.yang.lim@intel.com>
Co-authored-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
2025-10-17 19:30:25 +02:00
Ilia Larchenko 79137f58d1 Fixed a small wrist flex calibration issue for lekiwi (#1787)
wrist_flex is not full_turn_motor (it has only a 180-degree range) and should be calibrated like in so_100, only wrist_roll is a full turn motor

Signed-off-by: Ilia Larchenko <41329713+IliaLarchenko@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-17 18:14:53 +02:00
azaracla da9c2e66f4 fix: fix deprecated hugginface-cli whoami (#1884)
Signed-off-by: azaracla <33293244+azaracla@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-17 17:26:34 +02:00
Steven Palma 45730cc71e fix(docs): markdown formatting in integrate_hardware.mdx (#2232)
* Fixing some markdown formatting in the Step 4 section

* fix(docs): code block format

---------

Co-authored-by: Doug Harris <dharris@gmail.com>
2025-10-17 16:33:46 +02:00
yfynb1111 5d4af4b0b1 Fix: debug policy load pretrained model failure problem (#2073)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-17 16:32:56 +02:00
Edgar Riba 0050d7c61c docs: change video file path format in conversion script (#2113)
* Change video file path format in conversion script

Updated video file path in the dataset conversion script.

Signed-off-by: Edgar Riba <edgar.riba@gmail.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Edgar Riba <edgar.riba@gmail.com>

---------

Signed-off-by: Edgar Riba <edgar.riba@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-10-17 16:32:24 +02:00
Jade Choghari cf2897f545 Docs(fix): corrects minor mix-ups encoder/decoder (#2231) 2025-10-17 16:12:01 +02:00
Steven Palma 2c18210d02 chore(robots): deprecate strech, vipex and widowx robots (#2205) 2025-10-17 15:36:19 +02:00
dependabot[bot] 44bf283701 chore(deps): bump pypa/gh-action-pypi-publish (#1870)
Bumps the github_actions group with 1 update in the /.github/workflows directory: [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish).


Updates `pypa/gh-action-pypi-publish` from 1.12.4 to 1.13.0
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases)
- [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.4...v1.13.0)

---
updated-dependencies:
- dependency-name: pypa/gh-action-pypi-publish
  dependency-version: 1.13.0
  dependency-type: direct:production
  dependency-group: github_actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-17 15:33:37 +02:00
Antoine a51682b266 Optimized episode cache verification (#2166)
Signed-off-by: Antoine <antoine.dandigne@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-10-17 15:18:21 +02:00
Robin Glauser ed49c9935a Adding magnitude encoding bits for feetech motors according to https://github.com/Kotakku/FT_SCServo_Debug_Qt/blob/master/servo/sms_sts.h and https://gitee.com/ftservo/FTServo_Python/blob/main/scservo_sdk/sms_sts.py (#2223) 2025-10-17 15:15:03 +02:00
Infinity4B 52455d03a7 fix eval-related doc errors (#2183)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-10-17 14:34:21 +02:00
Steven Palma 4afb253825 fix(dependencies): wandb > 0.22.0 uses a different version of protobuf (#2230) 2025-10-17 13:59:31 +02:00
Steven Palma 96c664e09f fix(scripts): warmup in find cameras script (#2229) 2025-10-17 13:59:10 +02:00
Steven Palma 8bd0aec618 chore(ci): relax stale bot for PRs (#2222) 2025-10-16 17:44:50 +02:00
Pepijn e82e7a02e9 feat(train): add accelerate for multi gpu training (#2154)
* Enhance training and logging functionality with accelerator support

- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.

* Initialize logging in training script for both main and non-main processes

- Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode.
- This change enhances the clarity and consistency of logging during training sessions.

* add docs and only push model once

* Place  logging under accelerate and update docs

* fix pre commit

* only log in main process

* main logging

* try with local rank

* add tests

* change runner

* fix test

* dont push to hub in multi gpu tests

* pre download dataset in tests

* small fixes

* fix path optimizer state

* update docs, and small improvements in train

* simplify accelerate main process detection

* small improvements in train

* fix OOM bug

* change accelerate detection

* add some debugging

* always use accelerate

* cleanup update method

* cleanup

* fix bug

* scale lr decay if we reduce steps

* cleanup logging

* fix formatting

* encorperate feedback pr

* add min memory to cpu tests

* use accelerate to determin logging

* fix precommit and fix tests

* chore: minor details

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-16 17:41:55 +02:00
129 changed files with 9454 additions and 1865 deletions
+1 -1
@@ -78,7 +78,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot with all extras
run: uv sync --all-extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
+2 -1
@@ -189,5 +189,6 @@ jobs:
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
+11 -3
@@ -82,6 +82,14 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
@@ -103,7 +111,7 @@ jobs:
- name: Publish to TestPyPI for pre-releases
# True for tags like 'v0.2.0-rc1'
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
repository-url: https://test.pypi.org/legacy/
verbose: true
@@ -111,7 +119,7 @@ jobs:
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
verbose: true
print-hash: true
@@ -138,7 +146,7 @@ jobs:
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:
enable-cache: true
enable-cache: true # zizmor: ignore[cache-poisoning]
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Create uv virtual environment
+7 -5
@@ -27,15 +27,17 @@ env:
This issue was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
CLOSE_PR_MESSAGE: >
This PR was closed because it has been stalled for 14 days with no activity.
This PR was closed because it has been stalled for 21 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs.
Any change, comment or update to this issue will reset this count.
Thank you for your contributions.
WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs.
recent activity (1 year). It will be closed if no further activity occurs.
Any change, comment or update to this PR will reset this count.
Thank you for your contributions.
jobs:
@@ -56,10 +58,10 @@ jobs:
stale-pr-label: stale
exempt-issue-labels: never-stale
exempt-pr-labels: never-stale
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
days-before-issue-stale: 180
days-before-issue-close: 14
days-before-pr-stale: 180
days-before-pr-close: 14
days-before-pr-stale: 365
days-before-pr-close: 21
delete-branch: true
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
+1 -1
@@ -70,7 +70,7 @@ jobs:
echo "Dependencies unbound:" && cat pyproject.toml
- name: Install lerobot with all extras
run: uv sync --all-extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
- name: Run pytest (all extras)
run: uv run pytest tests -vv
+7 -7
@@ -26,7 +26,7 @@ repos:
##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-added-large-files
args: ['--maxkb=1024']
@@ -39,20 +39,20 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.4
rev: v0.14.1
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.34.0
rev: v1.38.1
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
@@ -68,12 +68,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.27.2
rev: v8.28.0
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.11.0
rev: v1.15.2
hooks:
- id: zizmor
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0
rev: v1.18.2
hooks:
- id: mypy
args: [--config-file=pyproject.toml]
+1 -1
@@ -137,7 +137,7 @@ Follow these steps to start contributing:
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda or miniconda:
Set up a development environment with conda:
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
+10 -9
@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `miniconda`, install `ffmpeg` in your environment:
When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -185,6 +185,11 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
### Weights & Biases
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -207,13 +212,13 @@ lerobot-dataset-viz \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
--mode local \
--episode-index 0
```
@@ -310,7 +315,7 @@ To upload these to the hub, run the following:
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
### Acknowledgment
@@ -337,7 +342,3 @@ If you want, you can cite this work with:
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)
```
```
+6
@@ -37,8 +37,12 @@
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
- local: envhub
title: Environments from the Hub
- local: il_sim
title: Imitation Learning in Sim
- local: libero
@@ -55,6 +59,8 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101
+418
@@ -0,0 +1,418 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handles **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
Different robot environments (LIBERO, MetaWorld, Aloha, etc.) each have their own data formats, coordinate systems, and conventions that need to be standardized **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:

# 1. Raw environment observation (numpy arrays, various formats)
raw_observation, reward, terminated, truncated, info = env.step(action)

# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)

# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)

# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
#    - Flatten robot states
#    - Rotate images to match dataset conventions
#    - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)

# 5. POLICY-SPECIFIC preprocessing
#    - Normalize with dataset statistics
#    - Add batch dimensions
#    - Move to GPU
#    - Tokenize language instructions
observation = preprocessor(observation)

# 6. Policy inference
action = policy.select_action(observation)

# 7. POLICY-SPECIFIC postprocessing
#    - Unnormalize actions
#    - Remove batch dimensions
action = postprocessor(action)

# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
#    - Convert action formats if needed
#    - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]

# 9. Execute in environment
env.step(action)
```
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
def preprocess(self, obs):
# Environment-specific: Flatten robot state (shouldn't be in policy!)
state = self._flatten_robot_state(obs["robot_state"])
# Policy-specific: Normalize with dataset stats
state = self.normalizer(state)
return state
# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]                           # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])    # 3D
        gripper = robot_state["gripper"]["qpos"]                      # 2D
        state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1)  # 8D
        return state

# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
    def _process_observation(self, obs):
        # Include velocities for a 14D state
        robot_state = obs["observation.robot_state"]
        eef_pos = robot_state["eef"]["pos"]                           # 3D
        eef_axisangle = quat2axisangle(robot_state["eef"]["quat"])    # 3D
        eef_vel = robot_state["eef"]["vel"]                           # 3D (NEW)
        gripper_pos = robot_state["gripper"]["qpos"]                  # 2D
        gripper_vel = robot_state["gripper"]["qvel"]                  # 3D (NEW)
        state = torch.cat([eef_pos, eef_axisangle, eef_vel,
                           gripper_pos, gripper_vel], dim=-1)         # 14D
        return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
"pixels": {"image": img, "image2": img2},
"robot_state": {
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
"gripper": {"qpos": ..., "qvel": ...},
"joints": {"pos": ..., "vel": ...}
}
}
# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
# Create environment
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
# Create policy
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
# Create policy processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
)
# Create environment processors (NEW!)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
# Run evaluation with both processor types
eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor, # Environment-specific
env_postprocessor=env_postprocessor, # Environment-specific
preprocessor=preprocessor, # Policy-specific
postprocessor=postprocessor, # Policy-specific
n_episodes=cfg.eval.n_episodes,
)
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from dataclasses import dataclass

import torch

from lerobot.processor.pipeline import ObservationProcessorStep
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
**State Processing:**
- Extracts end-effector position (3D)
- Converts quaternion to axis-angle representation (3D)
- Extracts gripper joint positions (2D)
- Concatenates into 8D state vector
**Image Processing:**
- Rotates images 180° to match HuggingFaceVLA/libero convention
"""
def _process_observation(self, observation):
processed_obs = observation.copy()
# Process images: Flip 180° for camera convention
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = processed_obs[key]
img = torch.flip(img, dims=[2, 3]) # Flip H and W
processed_obs[key] = img
# Process robot_state: Flatten to 8D vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
eef_pos = robot_state["eef"]["pos"] # (B, 3)
eef_quat = robot_state["eef"]["quat"] # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
state = state.float()
processed_obs["observation.state"] = state
return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning; a minimal sketch of this conversion follows this list)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
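For reference, here is a minimal sketch of what such a quaternion-to-axis-angle helper can look like. It is illustrative rather than the exact `_quat2axisangle` used by `LiberoProcessorStep`, and it assumes the `(x, y, z, w)` quaternion ordering used by the LIBERO simulator:
```python
import torch

def quat_to_axisangle(quat: torch.Tensor) -> torch.Tensor:
    """Convert batched quaternions (..., 4) in (x, y, z, w) order to axis-angle vectors (..., 3)."""
    xyz, w = quat[..., :3], quat[..., 3:4]
    norm = torch.linalg.norm(xyz, dim=-1, keepdim=True).clamp(min=1e-8)
    angle = 2.0 * torch.atan2(norm, w)  # rotation angle in radians
    return xyz / norm * angle           # unit axis scaled by the angle
```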
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
"""Process observations from MyEnv."""
def _process_observation(self, observation):
processed = observation.copy()
# Your environment-specific transformations
if "myenv.specific.state" in processed:
state = processed.pop("myenv.specific.state")
# Transform to standard format
processed["observation.state"] = self._transform_state(state)
return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
else:
preprocessor = PolicyProcessorPipeline(steps=[])
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
# env.type=myenv automatically selects MyEnvProcessorStep in the factory
lerobot-eval \
    --policy.path=lerobot/my_policy \
    --env.type=myenv \
    --eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
"""Convert policy actions to environment-specific format."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Convert from Cartesian to joint space
if self.action_space == "joint":
action = self.ik_solver(action)
# Example: Apply environment-specific safety limits
action = torch.clamp(action, self.min_action, self.max_action)
transition["action"] = action
return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
"""Transform actions between coordinate systems."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Policy outputs in world frame, env expects base frame
action = self.world_to_base_transform(action)
transition["action"] = action
return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments, as sketched below.
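As an illustration of the last point, a processor step can be exercised with hand-built tensors, with no simulator or policy in the loop. This is a rough sketch: the import path, key names, and shapes follow the `LiberoProcessorStep` example above and are assumptions, not the shipped test suite:
```python
import torch

# Import path assumed from the layout described above.
from lerobot.processor.env_processor import LiberoProcessorStep

def test_libero_processor_flattens_state():
    # Dummy batched observation mimicking the LIBERO format used above.
    obs = {
        "observation.images.image": torch.rand(1, 3, 256, 256),
        "observation.robot_state": {
            "eef": {"pos": torch.zeros(1, 3), "quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]])},
            "gripper": {"qpos": torch.zeros(1, 2)},
        },
    }
    processed = LiberoProcessorStep()._process_observation(obs)
    # 3D position + 3D axis-angle + 2D gripper = 8D state
    assert processed["observation.state"].shape == (1, 8)
```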
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
+424
@@ -0,0 +1,424 @@
# Loading Environments from the Hub
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a new model of collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
## Overview
With EnvHub, you can:
- Load environments from the Hub instantly
- Share your custom simulation tasks with the community
- Version control your environments using Git
- Distribute complex physics simulations without packaging hassles
## Quick Start
Loading an environment from the Hub is as simple as:
```python
from lerobot.envs.factory import make_env
# Load a hub environment (requires explicit consent to run remote code)
env = make_env("lerobot/cartpole-env", trust_remote_code=True)
```
<Tip warning={true}>
**Security Notice**: Loading environments from the Hub executes Python code
from third-party repositories. Only use `trust_remote_code=True` with
repositories you trust. We strongly recommend pinning to a specific commit
hash for reproducibility and security.
</Tip>
## What is EnvHub?
EnvHub is a framework that allows researchers and developers to:
1. **Publish environments** to the Hugging Face Hub as Git repositories
2. **Load environments** dynamically without installing them as packages
3. **Version and track** environment changes using Git semantics
4. **Discover** new simulation tasks shared by the community
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
## Repository Structure
To make your environment loadable from the Hub, your repository must contain at minimum:
### Required Files
**`env.py`** (or custom Python file)
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
- This function should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped; see the minimal sketch after this list)
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
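For the second case, the smallest possible `env.py` can return a plain `gym.Env` and let LeRobot handle the wrapping. This is an illustrative sketch (the task id is arbitrary):
```python
# env.py -- minimal single-environment variant
import gymnasium as gym

def make_env(n_envs: int = 1, use_async_envs: bool = False):
    # Returning a single gym.Env is allowed; it is wrapped into a vector env automatically.
    return gym.make("Pendulum-v1")
```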
### Optional Files
**`requirements.txt`**
- List any additional dependencies your environment needs
- Users will need to install these manually before loading your environment (see the example after this list)
**`README.md`**
- Document your environment: what task it implements, observation/action spaces, rewards, etc.
- Include usage examples and any special setup instructions
**`.gitignore`**
- Exclude unnecessary files from your repository
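As an example of the `requirements.txt` point above, dependencies can be fetched and installed before calling `make_env`. This is a sketch under assumptions: the repo id is hypothetical, and `repo_type="space"` follows the upload instructions later in this guide:
```python
import subprocess
import sys

from huggingface_hub import hf_hub_download

# Download the environment repo's requirements.txt from the Hub.
req_file = hf_hub_download(
    repo_id="username/my-custom-env",
    filename="requirements.txt",
    repo_type="space",
)
# Install the listed dependencies into the current Python environment.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_file])
```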
### Example Repository Structure
```
my-environment-repo/
├── env.py # Main environment definition (required)
├── requirements.txt # Dependencies (optional)
├── README.md # Documentation (recommended)
├── assets/ # Images, videos, etc. (optional)
│ └── demo.gif
└── configs/ # Config files if needed (optional)
└── task_config.yaml
```
## Creating Your Environment Repository
### Step 1: Define Your Environment
Create an `env.py` file with a `make_env` function:
```python
# env.py
import gymnasium as gym
def make_env(n_envs: int = 1, use_async_envs: bool = False):
"""
Create vectorized environments for your custom task.
Args:
n_envs: Number of parallel environments
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv
Returns:
gym.vector.VectorEnv or dict mapping suite names to vectorized envs
"""
def _make_single_env():
# Create your custom environment
return gym.make("CartPole-v1")
# Choose vector environment type
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Create vectorized environment
vec_env = env_cls([_make_single_env for _ in range(n_envs)])
return vec_env
```
### Step 2: Test Locally
Before uploading, test your environment locally:
```python
from lerobot.envs.utils import _load_module_from_path, _call_make_env, _normalize_hub_result
# Load your module
module = _load_module_from_path("./env.py")
# Test the make_env function
result = _call_make_env(module, n_envs=2, use_async_envs=False)
normalized = _normalize_hub_result(result)
# Verify it works
suite_name = next(iter(normalized))
env = normalized[suite_name][0]
obs, info = env.reset()
print(f"Observation shape: {obs.shape if hasattr(obs, 'shape') else type(obs)}")
env.close()
```
### Step 3: Upload to the Hub
Upload your repository to Hugging Face:
```bash
# Install huggingface_hub if needed
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
# Initialize git and push
git init
git add .
git commit -m "Initial environment implementation"
git remote add origin https://huggingface.co/my-org/my-custom-env
git push -u origin main
```
Alternatively, use the `huggingface_hub` Python API:
```python
from huggingface_hub import HfApi
api = HfApi()
# Create repository
api.create_repo("my-custom-env", repo_type="space")
# Upload files
api.upload_folder(
folder_path="./my-env-folder",
repo_id="username/my-custom-env",
repo_type="space",
)
```
## Loading Environments from the Hub
### Basic Usage
```python
from lerobot.envs.factory import make_env
# Load from the hub
envs_dict = make_env(
"username/my-custom-env",
n_envs=4,
trust_remote_code=True
)
# Access the environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Use it like any gym environment
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
```
### Advanced: Pinning to Specific Versions
For reproducibility and security, pin to a specific Git revision:
```python
# Pin to a specific branch
env = make_env("username/my-env@main", trust_remote_code=True)
# Pin to a specific commit (recommended for papers/experiments)
env = make_env("username/my-env@abc123def456", trust_remote_code=True)
# Pin to a tag
env = make_env("username/my-env@v1.0.0", trust_remote_code=True)
```
### Custom File Paths
If your environment definition is not in `env.py`:
```python
# Load from a custom file
env = make_env("username/my-env:custom_env.py", trust_remote_code=True)
# Combine with version pinning
env = make_env("username/my-env@v1.0:envs/task_a.py", trust_remote_code=True)
```
### Async Environments
For better performance with multiple environments:
```python
envs_dict = make_env(
"username/my-env",
n_envs=8,
use_async_envs=True, # Use AsyncVectorEnv for parallel execution
trust_remote_code=True
)
```
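Whether the async variant actually helps depends on how expensive a single environment step is; a quick way to check is to time a short rollout with both settings. This is an illustrative sketch only, reusing the placeholder repo id from above:
```python
import time
from lerobot.envs.factory import make_env

for use_async in (False, True):
    envs_dict = make_env(
        "username/my-env", n_envs=8, use_async_envs=use_async, trust_remote_code=True
    )
    env = envs_dict[next(iter(envs_dict))][0]
    env.reset(seed=0)
    start = time.perf_counter()
    for _ in range(100):
        env.step(env.action_space.sample())  # vectorized envs auto-reset finished episodes
    print(f"use_async_envs={use_async}: {time.perf_counter() - start:.2f}s for 100 steps")
    env.close()
```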
## URL Format Reference
The hub URL format supports several patterns:
| Pattern | Description | Example |
| -------------------- | ------------------------------ | -------------------------------------- |
| `user/repo` | Load `env.py` from main branch | `make_env("lerobot/pusht-env")` |
| `user/repo@revision` | Load from specific revision | `make_env("lerobot/pusht-env@main")` |
| `user/repo:path` | Load custom file | `make_env("lerobot/envs:pusht.py")` |
| `user/repo@rev:path` | Revision + custom file | `make_env("lerobot/envs@v1:pusht.py")` |
## Multi-Task Environments
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
```python
def make_env(n_envs: int = 1, use_async_envs: bool = False):
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Return dict: {suite_name: {task_id: VectorEnv}}
return {
"suite_1": {
0: env_cls([lambda: gym.make("Task1-v0") for _ in range(n_envs)]),
1: env_cls([lambda: gym.make("Task2-v0") for _ in range(n_envs)]),
},
"suite_2": {
0: env_cls([lambda: gym.make("Task3-v0") for _ in range(n_envs)]),
}
}
```
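A consumer of such a multi-task repo then simply iterates over suites and tasks. A minimal sketch (the repo id and suite/task layout are illustrative placeholders):
```python
from lerobot.envs.factory import make_env

envs_dict = make_env("username/my-multitask-env", n_envs=2, trust_remote_code=True)
for suite_name, tasks in envs_dict.items():
    for task_id, vec_env in tasks.items():
        obs, info = vec_env.reset(seed=0)
        obs, reward, terminated, truncated, info = vec_env.step(vec_env.action_space.sample())
        print(f"{suite_name} / task {task_id}: reward shape {reward.shape}")
        vec_env.close()
```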
## Security Considerations
<Tip warning={true}>
**Important**: The `trust_remote_code=True` flag is required to execute
environment code from the Hub. This is by design for security.
</Tip>
When loading environments from the Hub:
1. **Review the code first**: Visit the repository and inspect `env.py` before loading
2. **Pin to commits**: Use specific commit hashes for reproducibility
3. **Check dependencies**: Review `requirements.txt` for suspicious packages
4. **Use trusted sources**: Prefer official organizations or well-known researchers
5. **Sandbox if needed**: Run untrusted code in isolated environments (containers, VMs)
Example of safe usage:
```python
# ❌ BAD: Loading without inspection
env = make_env("random-user/untrusted-env", trust_remote_code=True)
# ✅ GOOD: Review code, then pin to specific commit
# 1. Visit https://huggingface.co/trusted-org/verified-env
# 2. Review the env.py file
# 3. Copy the commit hash
env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
```
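If you prefer to resolve the commit hash programmatically instead of copying it from the website, `huggingface_hub` can report the current revision. The sketch below assumes the environment lives in a Space repository, as created in Step 3 (adjust `repo_type` otherwise):
```python
from huggingface_hub import HfApi
from lerobot.envs.factory import make_env

api = HfApi()
# Resolve the commit hash of the revision you reviewed, then pin to it explicitly
sha = api.repo_info("trusted-org/verified-env", repo_type="space").sha
env = make_env(f"trusted-org/verified-env@{sha}", trust_remote_code=True)
```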
## Example: CartPole from the Hub
Here's a complete example using the reference CartPole environment:
```python
from lerobot.envs.factory import make_env
import numpy as np
# Load the environment
envs_dict = make_env("lerobot/cartpole-env", n_envs=4, trust_remote_code=True)
# Get the vectorized environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Run a simple episode
obs, info = env.reset()
done = np.zeros(env.num_envs, dtype=bool)
total_reward = np.zeros(env.num_envs)
while not done.all():
# Random policy
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward * ~done  # stop counting once an env has finished its first episode
    done |= terminated | truncated  # latch envs that have terminated or truncated
print(f"Average reward: {total_reward.mean():.2f}")
env.close()
```
## Benefits of EnvHub
### For Environment Authors
- **Easy distribution**: No PyPI packaging required
- **Version control**: Use Git for environment versioning
- **Rapid iteration**: Push updates instantly
- **Documentation**: Hub README renders beautifully
- **Community**: Reach LeRobot users directly
### For Researchers
- **Quick experiments**: Load any environment in one line
- **Reproducibility**: Pin to specific commits
- **Discovery**: Browse environments on the Hub
- **No conflicts**: No need to install conflicting packages
### For the Community
- **Growing ecosystem**: More diverse simulation tasks
- **Standardization**: Common `make_env` API
- **Collaboration**: Fork and improve existing environments
- **Accessibility**: Lower barrier to sharing research
## Troubleshooting
### "Refusing to execute remote code"
You must explicitly pass `trust_remote_code=True`:
```python
env = make_env("user/repo", trust_remote_code=True)
```
### "Module X not found"
The hub environment has dependencies you need to install:
```bash
# Check the repo's requirements.txt and install dependencies
pip install gymnasium numpy
```
### "make_env not found in module"
Your `env.py` must expose a `make_env` function:
```python
def make_env(n_envs: int, use_async_envs: bool):
# Your implementation
pass
```
### Environment returns wrong type
The `make_env` function must return one of the following (see the minimal sketch after this list):
- A `gym.vector.VectorEnv`, or
- A single `gym.Env`, or
- A dict `{suite_name: {task_id: VectorEnv}}`
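For reference, here is a minimal sketch of the accepted return shapes:
```python
import gymnasium as gym

def make_env(n_envs: int = 1, use_async_envs: bool = False):
    env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
    # Option A: return a VectorEnv directly
    return env_cls([lambda: gym.make("CartPole-v1") for _ in range(n_envs)])
    # Option B: return a single gym.Env and let the loader normalize it
    # return gym.make("CartPole-v1")
    # Option C: return {suite_name: {task_id: VectorEnv}} as shown in the multi-task section
```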
## Best Practices
1. **Document your environment**: Include observation/action space descriptions, reward structure, and termination conditions in your README
2. **Add requirements.txt**: List all dependencies with versions
3. **Test thoroughly**: Verify your environment works locally before pushing
4. **Use semantic versioning**: Tag releases with version numbers (see the tagging sketch after this list)
5. **Add examples**: Include usage examples in your README
6. **Keep it simple**: Minimize dependencies when possible
7. **License your work**: Add a LICENSE file to clarify usage terms
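For point 4, tags can also be created programmatically with a recent `huggingface_hub`. This sketch uses a placeholder repo id and assumes the repository was created as a Space in Step 3:
```python
from huggingface_hub import HfApi

api = HfApi()
# Tag the current revision so users can load it with "username/my-custom-env@v1.0.0"
api.create_tag(
    "username/my-custom-env",
    tag="v1.0.0",
    tag_message="First stable release",
    repo_type="space",
)
```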
## Future Directions
The EnvHub ecosystem enables exciting possibilities:
- **GPU-accelerated physics**: Share Isaac Gym or Brax environments
- **Photorealistic rendering**: Distribute environments with advanced graphics
- **Multi-agent scenarios**: Complex interaction tasks
- **Real-world simulators**: Digital twins of physical setups
- **Procedural generation**: Infinite task variations
- **Domain randomization**: Pre-configured DR pipelines
As more researchers and developers contribute, the diversity and quality of available environments will grow, benefiting the entire robotics learning community.
## See Also
- [Hugging Face Hub Documentation](https://huggingface.co/docs/hub/en/index)
- [Gymnasium Documentation](https://gymnasium.farama.org/index.html)
- [Example Hub Environment](https://huggingface.co/lerobot/cartpole-env)
+125
View File
@@ -0,0 +1,125 @@
# GR00T N1.5 Policy
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
This document outlines the specifics of its integration and usage within the LeRobot framework.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
- Real captured data from robots.
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
- Internet-scale video data.
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
## Installation Requirements
As of today, GR00T N1.5 requires Flash Attention for its internal operation.
We are working on making this optional, but in the meantime this requires an extra installation step, and the policy can only be used on CUDA-enabled devices.
1. Follow the Environment Setup section of our [Installation Guide](./installation). **Attention**: don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
3. Install LeRobot by running:
```bash
pip install lerobot[groot]
```
## Usage
To use GR00T in your LeRobot configuration, specify the policy type as:
```python
policy.type=groot
```
## Training
### Training Command Example
Here's a complete training command for finetuning the base GR00T model on your own dataset:
```bash
# Using a multi-GPU setup
accelerate launch \
--multi_gpu \
--num_processes=$NUM_GPUS \
$(which lerobot-train) \
--output_dir=$OUTPUT_DIR \
--save_checkpoint=true \
--batch_size=$BATCH_SIZE \
--steps=$NUM_STEPS \
--save_freq=$SAVE_FREQ \
--log_freq=$LOG_FREQ \
--policy.push_to_hub=true \
--policy.type=groot \
--policy.repo_id=$REPO_ID \
--policy.tune_diffusion_model=false \
--dataset.repo_id=$DATASET_ID \
--wandb.enable=true \
--wandb.disable_artifact=true \
--job_name=$JOB_NAME
```
## Performance Results
### Libero Benchmark Results
> [!NOTE]
> Follow our instructions for Libero usage: [Libero](./libero)
GR00T has demonstrated strong performance on the Libero benchmark suite. To validate the LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results against the GR00T reference.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
### Evaluate on your hardware setup
Once you have trained your model, you can run inference on your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
```bash
lerobot-record \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
--robot.id=bimanual_follower \
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
  --dataset.single_task="Grab and handover the red cube to the other arm" \
  --dataset.episode_time_s=30 \
  --dataset.reset_time_s=10 \
  --policy.path=<user>/groot-bimanual  # your trained model
```
## License
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).
+1 -1
View File
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | head -n 1)
echo $HF_USER
```
+12 -2
View File
@@ -1,8 +1,15 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
Create a virtual environment with Python 3.10, using conda:
```bash
conda create -y -n lerobot python=3.10
@@ -14,7 +21,7 @@ Then activate your conda environment, you have to do this each time you open a s
conda activate lerobot
```
When using `miniconda`, install `ffmpeg` in your environment:
When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -74,6 +81,9 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and the `ffmpeg` libraries.
+14 -12
View File
@@ -208,34 +208,36 @@ LeRobot supports saving and loading calibration data automatically. This is usef
<!-- prettier-ignore-start -->
```python
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
```
<!-- prettier-ignore-end -->
### `is_calibrated`
This should reflect whether your robot has the required calibration loaded.
<!-- prettier-ignore-start -->
```python
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
```
<!-- prettier-ignore-end -->
### `calibrate()`
The goal of the calibration is twofold:
- Know the physical range of motion of each motor in order to only send commands within this range.
- Normalize raw motor positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete values dependent on the specific motor used, which will not replicate elsewhere.
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
<!-- prettier-ignore-start -->
```python
def calibrate(self) -> None:
+5
View File
@@ -28,6 +28,11 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the `pi` extra, you will have to run: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be fixed in the next patch release.
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
+5
View File
@@ -36,6 +36,11 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the `pi` extra, you will have to run: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be fixed in the next patch release.
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
+27
View File
@@ -0,0 +1,27 @@
## Research Paper
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
Code: https://github.com/NVIDIA/Isaac-GR00T
## Citation
```bibtex
@inproceedings{gr00tn1_2025,
archivePrefix = {arxiv},
eprint = {2503.14734},
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
  author = {NVIDIA and Johan Bjorck and Fernando Castañeda and Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
month = {March},
year = {2025},
booktitle = {ArXiv Preprint},
}
```
## Additional Resources
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
+6
View File
@@ -0,0 +1,6 @@
python ./examples/dataset/convert_hdf5_lerobot.py \
--src-paths /fsx/jade_choghari/XVLA-Soft-Fold/0808_12am_stage_1_stage2new_new_cam_very_slow_no_sleeve \
--output-path /fsx/jade_choghari/new-data \
--executor local \
--tasks-per-job 3 \
--workers 10
+437
View File
@@ -0,0 +1,437 @@
import argparse
import os
import re
import shutil
from pathlib import Path
import pandas as pd
# import ray
# from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from lerobot.datasets.aggregate import (
aggregate_data,
aggregate_metadata,
aggregate_stats,
aggregate_videos,
validate_all_metadata,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
write_info,
write_stats,
write_tasks,
)
XVLA_SOFT_FOLD_FEATURES = {
"observation.images.cam_high": {
"dtype": "video",
"names": ["height", "width", "channels"],
"shape": (480, 640, 3),
"names": ["height", "width", "rgb"],
},
"observation.images.cam_left_wrist": {
"dtype": "video",
"names": ["height", "width", "channels"],
"shape": (480, 640, 3),
"names": ["height", "width", "rgb"],
},
"observation.images.cam_right_wrist": {
"dtype": "video",
"names": ["height", "width", "channels"],
"shape": (480, 640, 3),
"names": ["height", "width", "rgb"],
},
"observation.states.eef_euler": {
"dtype": "float32",
"shape": (14,), # 14 = 7 joints per arm × 2 arms OR 14-d state representation
"names": {"values": [f"eef_euler_{i}" for i in range(14)]},
},
"observation.states.eef_quaternion": {
"dtype": "float32",
"shape": (16,), # 16 = 8 quaternion floats per arm × 2 arms
"names": {"values": [f"eef_quat_{i}" for i in range(16)]},
},
"observation.states.eef_6d": {
"dtype": "float32",
"shape": (20,), # 20 = pos(3) + rot6d(6) + extra dims
"names": {"values": [f"eef6d_{i}" for i in range(20)]},
},
"observation.states.eef_left_time": {
"dtype": "float32",
"shape": (1,),
"names": {"values": ["eef_left_time"]},
},
"observation.states.eef_right_time": {
"dtype": "float32",
"shape": (1,),
"names": {"values": ["eef_right_time"]},
},
"observation.states.qpos": {
"dtype": "float32",
"shape": (14,), # 7 per arm × 2 arms
"names": {"motors": [f"qpos_{i}" for i in range(14)]},
},
"observation.states.qvel": {
"dtype": "float32",
"shape": (14,),
"names": {"motors": [f"qvel_{i}" for i in range(14)]},
},
"observation.states.effort": {
"dtype": "float32",
"shape": (14,),
"names": {"motors": [f"effort_{i}" for i in range(14)]},
},
"observation.states.qpos_left_time": {
"dtype": "float32",
"shape": (1,),
"names": {"values": ["qpos_left_time"]},
},
"observation.states.qpos_right_time": {
"dtype": "float32",
"shape": (1,),
"names": {"values": ["qpos_right_time"]},
},
"action": {
"dtype": "float32",
"shape": (14,),
"names": {"motors": [f"joint_action_{i}" for i in range(14)]},
},
"time_stamp": {
"dtype": "float32",
"shape": (1,),
"names": {"values": ["global_timestamp"]},
},
}
import cv2
import numpy as np
def decode_image(encoded_array):
# HDF5 gives you an array of uint8 → convert to raw bytes
data = np.asarray(encoded_array, dtype=np.uint8)
img = cv2.imdecode(data, cv2.IMREAD_COLOR) # returns HWC BGR
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # convert to RGB
return img
from pathlib import Path
import numpy as np
from h5py import File
def load_local_episodes(input_h5: Path):
"""
Load one XVLA Soft-Fold episode from a single .hdf5 file.
This dataset stores ONE episode per file, NOT a /data/ group.
"""
import h5py
import numpy as np
with h5py.File(input_h5, "r") as f:
# Determine episode length from any observation vector
episode_len = f["observations/eef_6d"].shape[0]
episode = []
for i in range(episode_len):
frame = {
# ----------------------
# ROOT-LEVEL
# ----------------------
"task": "fold the cloth",
"time_stamp": np.array([f["time_stamp"][i]], dtype=np.float32),
# ----------------------
# OBSERVATIONS
# ----------------------
"observation": {
"images": {
"cam_high": f["observations/images/cam_high"][i],
"cam_left_wrist": f["observations/images/cam_left_wrist"][i],
"cam_right_wrist": f["observations/images/cam_right_wrist"][i],
},
"states": {
"eef_euler": f["observations/eef"][i],
"eef_quaternion": f["observations/eef_quaternion"][i],
"eef_6d": f["observations/eef_6d"][i],
"eef_left_time": np.array([f["observations/eef_left_time"][i]], dtype=np.float32),
"eef_right_time": np.array([f["observations/eef_right_time"][i]], dtype=np.float32),
"qpos": f["observations/qpos"][i],
"qvel": f["observations/qvel"][i],
"effort": f["observations/effort"][i],
"qpos_left_time": np.array([f["observations/qpos_left_time"][i]], dtype=np.float32),
"qpos_right_time": np.array([f["observations/qpos_right_time"][i]], dtype=np.float32),
},
},
# ----------------------
# ACTION (your joint 14-D)
# ----------------------
"action": f["action"][i].astype(np.float32),
}
episode.append(frame)
yield episode
# from ray.runtime_env import RuntimeEnv
from tqdm import tqdm
def setup_logger():
import sys
from datatrove.utils.logging import logger
logger.remove()
logger.add(sys.stdout, level="INFO", colorize=True)
return logger
class SaveLerobotDataset(PipelineStep):
name = "Save Temp LerobotDataset"
type = "libero2lerobot"
def __init__(self, tasks: list[tuple[Path, Path, str]]):
super().__init__()
self.tasks = tasks
def run(self, data=None, rank: int = 0, world_size: int = 1):
logger = setup_logger()
input_h5, output_path, task_instruction = self.tasks[rank]
if output_path.exists():
shutil.rmtree(output_path)
dataset = LeRobotDataset.create(
repo_id=f"{input_h5.parent.name}/{input_h5.name}",
root=output_path,
fps=20,
robot_type="franka",
features=XVLA_SOFT_FOLD_FEATURES,
)
logger.info(f"start processing for {input_h5}, saving to {output_path}")
raw_dataset = load_local_episodes(input_h5)
for episode_index, episode_data in enumerate(raw_dataset):
with self.track_time("saving episode"):
for raw_frame in episode_data:
frame_data = {
"task": task_instruction,
# ---------------------- IMAGES ----------------------
"observation.images.cam_high": decode_image(raw_frame["observation"]["images"]["cam_high"]),
"observation.images.cam_left_wrist": decode_image(raw_frame["observation"]["images"]["cam_left_wrist"]),
"observation.images.cam_right_wrist": decode_image(raw_frame["observation"]["images"]["cam_right_wrist"]),
# ---------------------- EEF STATES ----------------------
"observation.states.eef_euler": raw_frame["observation"]["states"]["eef_euler"],
"observation.states.eef_quaternion": raw_frame["observation"]["states"]["eef_quaternion"],
"observation.states.eef_6d": raw_frame["observation"]["states"]["eef_6d"],
"observation.states.eef_left_time": raw_frame["observation"]["states"]["eef_left_time"],
"observation.states.eef_right_time": raw_frame["observation"]["states"]["eef_right_time"],
# ---------------------- JOINT STATES ----------------------
"observation.states.qpos": raw_frame["observation"]["states"]["qpos"],
"observation.states.qvel": raw_frame["observation"]["states"]["qvel"],
"observation.states.effort": raw_frame["observation"]["states"]["effort"],
"observation.states.qpos_left_time": raw_frame["observation"]["states"]["qpos_left_time"],
"observation.states.qpos_right_time": raw_frame["observation"]["states"]["qpos_right_time"],
# ---------------------- ACTION ----------------------
"action": raw_frame["action"],
# ---------------------- TIME ----------------------
"time_stamp": raw_frame["time_stamp"],
}
dataset.add_frame(frame_data)
dataset.save_episode()
logger.info(f"Processed {dataset.repo_id}, episode {episode_index}, len={len(episode_data)}")
def create_aggr_dataset(raw_dirs: list[Path], aggregated_dir: Path):
logger = setup_logger()
all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in raw_dirs]
fps, robot_type, features = validate_all_metadata(all_metadata)
if aggregated_dir.exists():
shutil.rmtree(aggregated_dir)
aggr_meta = LeRobotDatasetMetadata.create(
repo_id=f"{aggregated_dir.parent.name}/{aggregated_dir.name}",
root=aggregated_dir,
fps=fps,
robot_type=robot_type,
features=features,
)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
videos_idx = {key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys}
aggr_meta.episodes = {}
for src_meta in tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(
src_meta, aggr_meta, videos_idx, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE
)
data_idx = aggregate_data(src_meta, aggr_meta, data_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE)
meta_idx = aggregate_metadata(src_meta, aggr_meta, meta_idx, data_idx, videos_idx)
aggr_meta.info["total_episodes"] += src_meta.total_episodes
aggr_meta.info["total_frames"] += src_meta.total_frames
logger.info("write tasks")
write_tasks(aggr_meta.tasks, aggr_meta.root)
logger.info("write info")
aggr_meta.info.update(
{
"total_tasks": len(aggr_meta.tasks),
"total_episodes": sum(m.total_episodes for m in all_metadata),
"total_frames": sum(m.total_frames for m in all_metadata),
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
}
)
write_info(aggr_meta.info, aggr_meta.root)
logger.info("write stats")
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
write_stats(aggr_meta.stats, aggr_meta.root)
def delete_temp_data(temp_dirs: list[Path]):
logger = setup_logger()
logger.info("Delete temp data_dir")
for temp_dir in temp_dirs:
shutil.rmtree(temp_dir)
def main(
src_paths: list[Path],
output_path: Path,
executor: str,
cpus_per_task: int,
tasks_per_job: int,
workers: int,
resume_dir: Path = None,
debug: bool = False,
repo_id: str = None,
push_to_hub: bool = False,
):
tasks = []
for src_path in src_paths:
for input_h5 in src_path.glob("*.hdf5"):
tasks.append(
(
input_h5,
(output_path / (src_path.name + "_temp") / input_h5.stem).resolve(),
"fold the cloth", # fixed single task
)
)
if len(src_paths) > 1:
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
else:
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
aggregate_output_path = aggregate_output_path.resolve()
if debug:
executor = "local"
workers = 1
tasks = tasks[:2]
push_to_hub = False
match executor:
case "local":
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
executor = LocalPipelineExecutor
# case "ray":
# runtime_env = RuntimeEnv(
# env_vars={
# "HDF5_USE_FILE_LOCKING": "FALSE",
# "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
# "SVT_LOG": "1",
# },
# )
# ray.init(runtime_env=runtime_env)
# executor = RayPipelineExecutor
case _:
raise ValueError(f"Executor {executor} not supported")
executor_config = {
"tasks": len(tasks),
"workers": workers,
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if False else {}),
}
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_dir).run()
create_aggr_dataset([task[1] for task in tasks], aggregate_output_path)
delete_temp_data([task[1] for task in tasks])
for task in tasks:
shutil.rmtree(task[1].parent, ignore_errors=True)
if push_to_hub:
assert repo_id is not None
tags = ["LeRobot", "libero", "franka"]
tags.extend([src_path.name for src_path in src_paths])
LeRobotDataset(
repo_id=repo_id,
root=aggregate_output_path,
).push_to_hub(
tags=tags,
private=False,
push_videos=True,
license="apache-2.0",
upload_large_folder=False,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--src-paths", type=Path, nargs="+", required=True)
parser.add_argument("--output-path", type=Path, required=True)
parser.add_argument("--executor", type=str, choices=["local", "ray"], default="local")
parser.add_argument("--cpus-per-task", type=int, default=1)
parser.add_argument("--tasks-per-job", type=int, default=1, help="number of concurrent tasks per job, only used for ray")
parser.add_argument("--workers", type=int, default=-1, help="number of concurrent jobs to run")
parser.add_argument("--resume-dir", type=Path, help="logs directory to resume")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
parser.add_argument("--push-to-hub", action="store_true", help="upload to hub")
args = parser.parse_args()
main(**vars(args))
+12 -14
View File
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
if __name__ == "__main__":
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
+36
View File
@@ -0,0 +1,36 @@
from pathlib import Path
import numpy as np
from h5py import File
def load_local_episodes(input_h5: Path):
with File(input_h5, "r") as f:
for demo in f["data"].values():
demo_len = len(demo["obs/agentview_rgb"])
# (-1: open, 1: close) -> (0: close, 1: open)
action = np.array(demo["actions"])
action = np.concatenate(
[
action[:, :6],
(1 - np.clip(action[:, -1], 0, 1))[:, None],
],
axis=1,
)
state = np.concatenate(
[
np.array(demo["obs/ee_states"]),
np.array(demo["obs/gripper_states"]),
],
axis=1,
)
episode = {
"observation.images.image": np.array(demo["obs/agentview_rgb"]),
"observation.images.wrist_image": np.array(demo["obs/eye_in_hand_rgb"]),
"observation.state": np.array(state, dtype=np.float32),
"observation.states.ee_state": np.array(demo["obs/ee_states"], dtype=np.float32),
"observation.states.joint_state": np.array(demo["obs/joint_states"], dtype=np.float32),
"observation.states.gripper_state": np.array(demo["obs/gripper_states"], dtype=np.float32),
"action": np.array(action, dtype=np.float32),
}
yield [{**{k: v[i] for k, v in episode.items()}} for i in range(demo_len)]
@@ -15,16 +15,12 @@
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
from port_droid import DROID_SHARDS
class AggregateDatasets(PipelineStep):
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
+2 -2
View File
@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
+9 -3
View File
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
UploadDataset(repo_id, private=private),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,6 +267,12 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
parser.add_argument(
"--private",
action="store_true",
default=False,
help="Whether to create a private repository.",
)
init_logging()
@@ -0,0 +1,98 @@
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
@@ -0,0 +1,57 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
# find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used to load the right calibration files
follower_id = ...  # something like "follower_so100"
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
@@ -0,0 +1,11 @@
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve
host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080
config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)
@@ -0,0 +1,55 @@
import threading
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used to load the right calibration files
follower_id = ...  # something like "follower_so100"
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
server_address = ... # something like "127.0.0.1:8080" if using localhost
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
    chunk_size_threshold=0.5,
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 4. Create and start client
client = RobotClient(client_cfg)
# 5. Provide a textual description of the task
task = ...
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)
@@ -0,0 +1,99 @@
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, Diffusion Policy expects a fixed number of past observations and future actions
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
@@ -0,0 +1,60 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"
model = DiffusionPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used to load the right calibration files
follower_id = ...  # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
@@ -0,0 +1,67 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/pi0_base"
model = PI0Policy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
    # This override allows running on MPS; otherwise it defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used to load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
camera_config = {
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
+345
View File
@@ -0,0 +1,345 @@
import multiprocessing as mp
import signal
from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents
LOG_EVERY = 10
SEND_EVERY = 10
def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
batch_size: int = 32,
device: torch.device = "mps",
):
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
for the actor to adopt."""
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
training_step = 0
while not shutdown_event.is_set():
# retrieve incoming transitions from the actor process
try:
transitions = transitions_queue.get(timeout=0.1)
for transition in transitions:
# HIL-SERL: Add ALL transitions to online buffer
online_buffer.add(**transition)
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
if is_intervention:
offline_buffer.add(**transition)
print(
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
)
except Empty:
pass # No transitions available, continue
# Train if we have enough data
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
# Sample from online buffer (autonomous + human data)
online_batch = online_buffer.sample(batch_size // 2)
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
offline_batch = offline_buffer.sample(batch_size // 2)
# Combine batches - this is the key HIL-SERL mechanism!
batch = {}
for key in online_batch:
if key in offline_batch:
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_step += 1
if training_step % LOG_EVERY == 0:
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
pass
print("[LEARNER] Learner process finished")
def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
output_directory: Path | None = None,
):
"""The actor process - interacts with environment and collects data.
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
policy_actor.eval()
policy_actor.to(device)
reward_classifier.eval()
reward_classifier.to(device)
# Create robot environment inside the actor process
env, teleop_device = make_robot_env(env_cfg)
try:
for episode in range(MAX_EPISODES):
if shutdown_event.is_set():
break
obs, _info = env.reset()
episode_reward = 0.0
step = 0
episode_transitions = []
print(f"[ACTOR] Starting episode {episode + 1}")
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
done = terminated or truncated
# Predict reward
policy_next_obs = make_policy_obs(next_obs, device=device)
reward = reward_classifier.predict_reward(policy_next_obs)
if reward >= 1.0 and not done: # success detected! halt episode
terminated = True
done = True
# In HIL-SERL, human interventions come from the teleop device
is_intervention = False
if hasattr(teleop_device, "get_teleop_events"):
# Real intervention detection from teleop device
teleop_events = teleop_device.get_teleop_events()
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
# Store transition with intervention metadata
transition = {
"state": policy_obs,
"action": action,
"reward": float(reward) if hasattr(reward, "item") else reward,
"next_state": policy_next_obs,
"done": done,
"truncated": truncated,
"complementary_info": {
"is_intervention": is_intervention,
},
}
episode_transitions.append(transition)
episode_reward += reward
step += 1
obs = next_obs
if done:
break
# Send episode transitions to learner
transitions_queue.put_nowait(episode_transitions)
except KeyboardInterrupt:
print("[ACTOR] Interrupted by user")
finally:
# Clean up
if hasattr(env, "robot") and env.robot.is_connected:
env.robot.disconnect()
if teleop_device and hasattr(teleop_device, "disconnect"):
teleop_device.disconnect()
if output_directory is not None:
policy_actor.save_pretrained(output_directory)
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
print("[ACTOR] Actor process finished")
def make_policy_obs(obs, device: torch.device = "cpu"):
return {
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
**{
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
for k in obs["pixels"]
},
}
"""Main function - coordinates actor and learner processes."""
device = "mps" # or "cuda" or "cpu"
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
output_directory.mkdir(parents=True, exist_ok=True)
# find ports using lerobot-find-port
follower_port = ...
leader_port = ...
# the robot ids are used to load the right calibration files
follower_id = ...
leader_id = ...
# A pretrained model (to be used in-distribution!)
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
reward_classifier.to(device)
reward_classifier.eval()
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
# Create robot environment
env, teleop_device = make_robot_env(env_cfg)
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
# Online buffer: initialized from scratch
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
)
# Create communication channels between learner and actor processes
transitions_queue = mp.Queue(maxsize=10)
parameters_queue = mp.Queue(maxsize=2)
shutdown_event = mp.Event()
# Signal handler for graceful shutdown
def signal_handler(sig, frame):  # signal handlers receive (signum, frame)
print(f"\nSignal {sig} received, shutting down...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Create processes
learner_process = mp.Process(
target=run_learner,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_learner,
online_replay_buffer,
offline_replay_buffer,
),
kwargs={"device": device}, # can run on accelerated hardware for training
)
actor_process = mp.Process(
target=run_actor,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_actor,
reward_classifier,
env_cfg,
output_directory,
),
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
)
learner_process.start()
actor_process.start()
try:
# Wait for actor to finish (it controls the episode loop)
actor_process.join()
shutdown_event.set()
learner_process.join(timeout=10)
except KeyboardInterrupt:
print("Main process interrupted")
shutdown_event.set()
actor_process.join(timeout=5)
learner_process.join(timeout=10)
finally:
if learner_process.is_alive():
learner_process.terminate()
if actor_process.is_alive():
actor_process.terminate()
@@ -0,0 +1,62 @@
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
# Device to use for training
device = "mps" # or "cuda", or "cpu"
# Load the dataset used for training
repo_id = "lerobot/example_hil_serl_dataset"
dataset = LeRobotDataset(repo_id)
# Configure the policy to extract features from the image frames
camera_keys = dataset.meta.camera_keys
config = RewardClassifierConfig(
num_cameras=len(camera_keys),
device=device,
# backbone model to extract features from the image frames
model_name="microsoft/resnet-18",
)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
# Instantiate a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
total_accuracy = 0
for batch in dataloader:
# Preprocess the batch and move it to the correct device.
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_accuracy += output_dict["accuracy"]
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_accuracy / len(dataloader)
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)
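Once pushed to the Hub, the classifier can be reused exactly as in the actor loop above. A minimal sketch (the import path for `Classifier` is an assumption; use the same import as the earlier actor snippet, and build `policy_obs` with the same helper used during data collection):
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier  # assumed import path
reward_classifier = Classifier.from_pretrained("fracapuano/reward_classifier_hil_serl_example")
reward_classifier.to(device)
reward_classifier.eval()
reward = reward_classifier.predict_reward(policy_obs)  # policy_obs: dict of batched observation tensors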
@@ -0,0 +1,66 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/smolvla_base"
model = SmolVLAPolicy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This override allows running on MPS; otherwise the preprocessor defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used to load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the names and resolutions of the cameras used for training!
# You can check the camera keys a model expects in the info.json file of its model card on the Hub
camera_config = {
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
+37 -27
View File
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.3.4"
version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -74,15 +74,15 @@ dependencies = [
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.20.0,<0.23.0",
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.0.0",
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
# Support dependencies
"deepdiff>=7.0.1,<9.0.0",
@@ -97,7 +97,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -113,17 +113,23 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
# ] # TODO: Currently not supported
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
"Pillow>=10.0.0,<13.0.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
@@ -136,8 +142,8 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
metaworld = ["metaworld>=3.0.0"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
# All
all = [
@@ -150,6 +156,7 @@ all = [
"lerobot[intelrealsense]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -234,9 +241,6 @@ exclude_dirs = [
"tests",
"benchmarks",
"src/lerobot/datasets/push_dataset_to_hub",
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603", "B615"]
@@ -251,6 +255,8 @@ default.extend-ignore-identifiers-re = [
"pn",
"ser",
"ein",
"thw",
"inpt",
]
# TODO: Uncomment when ready to use
@@ -289,7 +295,6 @@ ignore_errors = true
[[tool.mypy.overrides]]
module = "lerobot.envs.*"
# Enable type checking only for the envs module
ignore_errors = false
@@ -297,17 +302,22 @@ ignore_errors = false
# module = "lerobot.utils.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.configs.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.configs.*"
ignore_errors = false
# extra strictness for configs
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.model.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.model.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.processor.*"
@@ -317,9 +327,9 @@ ignore_errors = false
# module = "lerobot.datasets.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.cameras.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
+325 -120
View File
@@ -1,3 +1,4 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
@@ -12,47 +13,62 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
accelerate==1.9.0
# via lerobot
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.15
aiohttp==3.13.1
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.3.0
attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.0.0
av==15.1.0
# via lerobot
blinker==1.9.0
# via flask
certifi==2025.7.14
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# requests
# sentry-sdk
cffi==1.17.1
cffi==2.0.0
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.2
charset-normalizer==3.4.4
# via requests
click==8.2.1
click==8.3.0
# via
# flask
# uvicorn
# wandb
cloudpickle==3.1.1
# via gymnasium
cmake==4.0.3
# via
# gymnasium
# libero
cmake==4.1.0
# via lerobot
cmeel==0.57.3
# via
@@ -94,27 +110,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.10.1
coverage[toml]==7.11.0
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==3.6.0
datasets==4.1.1
# via lerobot
debugpy==1.8.15
debugpy==1.8.17
# via lerobot
decorator==5.2.1
# via ipython
deepdiff==8.5.0
deepdiff==8.6.1
# via lerobot
diffusers==0.34.0
diffusers==0.35.2
# via lerobot
dill==0.3.8
dill==0.4.0
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.14
dm-control==1.0.34
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -122,29 +138,45 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.7.31
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via lerobot
# via
# lerobot
# libero
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.0
executing==2.2.1
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.18.0
filelock==3.20.0
# via
# datasets
# diffusers
@@ -152,24 +184,25 @@ filelock==3.18.0
# torch
# transformers
# virtualenv
flask==3.1.1
# via lerobot
fonttools==4.59.0
fonttools==4.60.1
# via matplotlib
frozenlist==1.7.0
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.3.0
fsspec[http]==2025.9.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.9.0
glfw==2.10.0
# via
# dm-control
# mujoco
@@ -177,61 +210,79 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-aloha==0.1.1
gym-hil==0.1.13
# via lerobot
gym-hil==0.1.10
gym-pusht==0.1.6
# via lerobot
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
gymnasium==1.2.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.5
hf-xet==1.1.10
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
huggingface-hub[cli,hf-transfer]==0.34.3
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
identify==2.6.12
hydra-core==1.3.2
# via libero
identify==2.6.15
# via pre-commit
idna==3.10
idna==3.11
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via imageio
# via
# imageio
# robomimic
importlib-metadata==8.7.0
# via diffusers
iniconfig==2.1.0
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -239,50 +290,71 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via
# flask
# gymnasium-robotics
# torch
# via torch
jsonlines==4.0.0
# via lerobot
kiwisolver==1.4.8
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
lxml==6.0.0
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
# via dm-control
markupsafe==3.0.2
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==2.3.7
mujoco==3.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# libero
# metaworld
# robosuite
multidict==6.7.0
# via
# aiohttp
# yarl
@@ -290,17 +362,25 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
@@ -309,25 +389,43 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
# gymnasium-robotics
# h5py
# hebi-py
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# pettingzoo
# peft
# pyquaternion
# reachy2-sdk
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via gym-pusht
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -337,53 +435,63 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.1
pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.4
parso==0.8.5
# via jedi
pettingzoo==1.24.3
# via gymnasium-robotics
peft==0.17.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==11.3.0
pillow==12.0.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.3.8
platformdirs==4.5.0
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.2.0
pre-commit==4.3.0
# via lerobot
prompt-toolkit==3.0.51
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
propcache==0.3.2
propcache==0.4.1
# via
# aiohttp
# yarl
@@ -392,11 +500,17 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.0.0
psutil==7.1.1
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -405,11 +519,13 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.22
pycparser==2.23
# via cffi
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
# via pydantic
pygame==2.6.1
# via
@@ -424,40 +540,42 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.2.12
pyngrok==7.4.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyobjc-core==11.1
pyobjc-core==12.0
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-applicationservices==11.1
pyobjc-framework-applicationservices==12.0
# via pynput
pyobjc-framework-cocoa==11.1
pyobjc-framework-cocoa==12.0
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-coretext==11.1
pyobjc-framework-coretext==12.0
# via pyobjc-framework-applicationservices
pyobjc-framework-quartz==11.1
pyobjc-framework-quartz==12.0
# via
# pynput
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
pyopengl==3.1.9
pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.3
pyparsing==3.2.5
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2-macosx==2.54.2
# via lerobot
pyserial==3.5
@@ -465,12 +583,14 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.1
pytest==8.4.2
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
pytest-cov==6.2.1
# teleop
pytest-cov==7.0.0
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -478,46 +598,73 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
pytz==2025.2
# via pandas
pyyaml==6.0.2
pyyaml==6.0.3
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.0.0
pyzmq==27.1.0
# via
# lerobot
# meshcat
regex==2025.7.34
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
# via
# diffusers
# transformers
requests==2.32.4
requests==2.32.5
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.22.1
rerun-sdk==0.26.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
safetensors==0.5.3
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -526,10 +673,12 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.34.1
sentry-sdk==2.42.1
# via wandb
shapely==2.1.1
shapely==2.1.2
# via gym-pusht
six==1.17.0
# via
@@ -537,64 +686,106 @@ six==1.17.0
# python-dateutil
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
termcolor==3.1.0
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
# via scikit-image
tokenizers==0.21.4
timm==1.0.20
# via lerobot
tokenizers==0.22.1
# via transformers
toml==0.10.2
# via draccus
tomli==2.2.1
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via lerobot
tornado==6.5.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
transformers==4.51.3
# via lerobot
typing-extensions==4.14.1
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.1
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via pandas
@@ -604,22 +795,36 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
virtualenv==20.32.0
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
# via pre-commit
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via flask
wrapt==1.17.2
# via tensorboard
wrapt==2.0.0
# via dm-tree
xxhash==3.5.0
xxhash==3.6.0
# via datasets
yarl==1.20.1
yarl==1.22.0
# via aiohttp
zipp==3.23.0
# via importlib-metadata
# via
# etils
# importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
+325 -114
View File
@@ -13,47 +13,62 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
accelerate==1.9.0
# via lerobot
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.12.15
aiohttp==3.13.1
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.3.0
attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.0.0
av==15.1.0
# via lerobot
blinker==1.9.0
# via flask
certifi==2025.7.14
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# requests
# sentry-sdk
cffi==1.17.1
cffi==2.0.0
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.2
charset-normalizer==3.4.4
# via requests
click==8.2.1
click==8.3.0
# via
# flask
# uvicorn
# wandb
cloudpickle==3.1.1
# via gymnasium
cmake==4.0.3
# via
# gymnasium
# libero
cmake==4.1.0
# via lerobot
cmeel==0.57.3
# via
@@ -95,27 +110,29 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.10.1
coverage[toml]==7.11.0
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==3.6.0
datasets==4.1.1
# via lerobot
debugpy==1.8.15
debugpy==1.8.17
# via lerobot
decorator==5.2.1
# via ipython
deepdiff==8.5.0
decord==0.6.0
# via lerobot
diffusers==0.34.0
deepdiff==8.6.1
# via lerobot
dill==0.3.8
diffusers==0.35.2
# via lerobot
dill==0.4.0
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.14
dm-control==1.0.34
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -123,31 +140,48 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.7.31
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via lerobot
# via
# flash-attn
# lerobot
# libero
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
evdev==1.9.2
# via pynput
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.0
executing==2.2.1
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.18.0
filelock==3.20.0
# via
# datasets
# diffusers
@@ -155,24 +189,27 @@ filelock==3.18.0
# torch
# transformers
# virtualenv
flask==3.1.1
flash-attn==2.8.3
# via lerobot
fonttools==4.59.0
fonttools==4.60.1
# via matplotlib
frozenlist==1.7.0
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.3.0
fsspec[http]==2025.9.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.9.0
glfw==2.10.0
# via
# dm-control
# mujoco
@@ -180,61 +217,79 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-aloha==0.1.1
gym-hil==0.1.13
# via lerobot
gym-hil==0.1.10
gym-pusht==0.1.6
# via lerobot
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
gymnasium==1.2.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.5
hf-xet==1.1.10
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
huggingface-hub[cli,hf-transfer]==0.34.3
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
identify==2.6.12
hydra-core==1.3.2
# via libero
identify==2.6.15
# via pre-commit
idna==3.10
idna==3.11
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via imageio
# via
# imageio
# robomimic
importlib-metadata==8.7.0
# via diffusers
iniconfig==2.1.0
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -242,50 +297,71 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via
# flask
# gymnasium-robotics
# torch
# via torch
jsonlines==4.0.0
# via lerobot
kiwisolver==1.4.8
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
lxml==6.0.0
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
# via dm-control
markupsafe==3.0.2
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==2.3.7
mujoco==3.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# libero
# metaworld
# robosuite
multidict==6.7.0
# via
# aiohttp
# yarl
@@ -293,42 +369,63 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
# decord
# diffusers
# dm-control
# dm-env
# dm-tree
# gymnasium
# gymnasium-robotics
# h5py
# hebi-py
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# pettingzoo
# peft
# pyquaternion
# reachy2-sdk
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
@@ -366,8 +463,14 @@ nvidia-nvjitlink-cu12==12.6.85
# torch
nvidia-nvtx-cu12==12.6.77
# via torch
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via gym-pusht
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -377,53 +480,63 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.1
pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.4
parso==0.8.5
# via jedi
pettingzoo==1.24.3
# via gymnasium-robotics
peft==0.17.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==11.3.0
pillow==12.0.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.3.8
platformdirs==4.5.0
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.2.0
pre-commit==4.3.0
# via lerobot
prompt-toolkit==3.0.51
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
propcache==0.3.2
propcache==0.4.1
# via
# aiohttp
# yarl
@@ -432,11 +545,17 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.0.0
psutil==7.1.1
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -445,11 +564,13 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.22
pycparser==2.23
# via cffi
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
# via pydantic
pygame==2.6.1
# via
@@ -464,20 +585,22 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.2.12
pyngrok==7.4.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyopengl==3.1.9
pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.3
pyparsing==3.2.5
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2==2.56.5.9235
# via lerobot
pyserial==3.5
@@ -485,12 +608,14 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.1
pytest==8.4.2
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
pytest-cov==6.2.1
# teleop
pytest-cov==7.0.0
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -498,48 +623,75 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
# via pandas
pyyaml==6.0.2
pyyaml==6.0.3
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.0.0
pyzmq==27.1.0
# via
# lerobot
# meshcat
regex==2025.7.34
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
# via
# diffusers
# transformers
requests==2.32.4
requests==2.32.5
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.22.1
rerun-sdk==0.26.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
safetensors==0.5.3
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -548,10 +700,12 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.34.1
sentry-sdk==2.42.1
# via wandb
shapely==2.1.1
shapely==2.1.2
# via gym-pusht
six==1.17.0
# via
@@ -560,66 +714,109 @@ six==1.17.0
# python-xlib
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
termcolor==3.1.0
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
# via scikit-image
tokenizers==0.21.4
timm==1.0.20
# via lerobot
tokenizers==0.22.1
# via transformers
toml==0.10.2
# via draccus
tomli==2.2.1
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# flash-attn
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via lerobot
tornado==6.5.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
transformers==4.51.3
# via lerobot
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
triton==3.3.1
# via torch
typing-extensions==4.14.1
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.1
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via pandas
@@ -629,22 +826,36 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
virtualenv==20.32.0
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
# via pre-commit
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via flask
wrapt==1.17.2
# via tensorboard
wrapt==2.0.0
# via dm-tree
xxhash==3.5.0
xxhash==3.6.0
# via datasets
yarl==1.20.1
yarl==1.22.0
# via aiohttp
zipp==3.23.0
# via importlib-metadata
# via
# etils
# importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
+4 -4
View File
@@ -1,9 +1,9 @@
# requirements.in
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]
+2 -1
View File
@@ -16,7 +16,7 @@ import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
import torch
@@ -268,6 +268,7 @@ class RemotePolicyConfig:
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
rename_map: dict[str, str] = field(default_factory=dict)
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
+4 -1
View File
@@ -159,7 +159,10 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
self.preprocessor, self.postprocessor = make_pre_post_processors(
self.policy.config,
pretrained_path=policy_specs.pretrained_name_or_path,
preprocessor_overrides={"device_processor": device_override},
preprocessor_overrides={
"device_processor": device_override,
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
},
postprocessor_overrides={"device_processor": device_override},
)
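The `rename_map` shown here travels from the client's `RemotePolicyConfig` into the server-side preprocessor override, so raw robot observation keys can be remapped to the keys the policy was trained with. A hypothetical sketch of the mapping (both camera key names are made up, and other `RemotePolicyConfig` fields are omitted):
rename_map = {
    # robot's raw camera key -> key expected by the pretrained policy
    "observation.images.wrist_cam": "observation.images.camera1",
}
# client side: RemotePolicyConfig(..., rename_map=rename_map)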
+3 -3
View File
@@ -17,7 +17,7 @@
import abc
from typing import Any
import numpy as np
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
from .configs import CameraConfig, ColorMode
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""Capture and return a single frame from the camera.
Args:
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
"""Asynchronously capture and return a single frame from the camera.
Args:
+3 -3
View File
@@ -18,7 +18,7 @@ import abc
from dataclasses import dataclass
from enum import Enum
import draccus
import draccus # type: ignore # TODO: add type stubs for draccus
class ColorMode(str, Enum):
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
fps: int | None = None
width: int | None = None
height: int | None = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
return str(self.get_choice_name(self.__class__))
+2
View File
@@ -14,3 +14,5 @@
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
+71 -15
View File
@@ -25,11 +25,12 @@ from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -121,7 +122,7 @@ class OpenCVCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -140,7 +141,7 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
def connect(self, warmup: bool = True):
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -180,12 +181,14 @@ class OpenCVCamera(Camera):
def _configure_capture_settings(self) -> None:
"""
Applies the specified FPS, width, and height settings to the connected camera.
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
This method attempts to set the camera properties via OpenCV. It checks if
the camera successfully applied the settings and raises an error if not.
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
Args:
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
fps: The desired frames per second. If None, the setting is skipped.
width: The desired capture width. If None, the setting is skipped.
height: The desired capture height. If None, the setting is skipped.
@@ -199,10 +202,11 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
self._validate_fourcc()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -216,18 +220,56 @@ class OpenCVCamera(Camera):
else:
self._validate_width_and_height()
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.fps is None:
raise ValueError(f"{self} FPS is not set")
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
# Use math.isclose for robust float comparison
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
def _validate_fourcc(self) -> None:
"""Validates and sets the camera's FOURCC code."""
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
# Convert actual FOURCC code back to string for comparison
actual_fourcc_code_int = int(actual_fourcc_code)
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
if not success or actual_fourcc != self.config.fourcc:
logger.warning(
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
f"Continuing with default format."
)
def _validate_width_and_height(self) -> None:
"""Validates and sets the camera's frame capture width and height."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.capture_width is None or self.capture_height is None:
raise ValueError(f"{self} capture_width or capture_height is not set")
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
@@ -258,11 +300,12 @@ class OpenCVCamera(Camera):
"""
found_cameras_info = []
targets_to_scan: list[str | int]
if platform.system() == "Linux":
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
targets_to_scan = [str(p) for p in possible_paths]
else:
targets_to_scan = list(range(MAX_OPENCV_INDEX))
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
for target in targets_to_scan:
camera = cv2.VideoCapture(target)
@@ -271,6 +314,12 @@ class OpenCVCamera(Camera):
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
default_fps = camera.get(cv2.CAP_PROP_FPS)
default_format = camera.get(cv2.CAP_PROP_FORMAT)
# Get FOURCC code and convert to string
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
default_fourcc_code_int = int(default_fourcc_code)
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
camera_info = {
"name": f"OpenCV Camera @ {target}",
"type": "OpenCV",
@@ -278,6 +327,7 @@ class OpenCVCamera(Camera):
"backend_api": camera.getBackendName(),
"default_stream_profile": {
"format": default_format,
"fourcc": default_fourcc,
"width": default_width,
"height": default_height,
"fps": default_fps,
@@ -289,7 +339,7 @@ class OpenCVCamera(Camera):
return found_cameras_info
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -317,6 +367,9 @@ class OpenCVCamera(Camera):
start_time = time.perf_counter()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret or frame is None:
@@ -329,7 +382,7 @@ class OpenCVCamera(Camera):
return processed_frame
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
@@ -372,7 +425,7 @@ class OpenCVCamera(Camera):
return processed_image
def _read_loop(self):
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -383,6 +436,9 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -419,7 +475,7 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -462,7 +518,7 @@ class OpenCVCamera(Camera):
return frame
def disconnect(self):
def disconnect(self) -> None:
"""
Disconnects from the camera and cleans up resources.
@@ -17,6 +17,8 @@ from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@dataclass
@@ -33,8 +35,9 @@ class OpenCVCameraConfig(CameraConfig):
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
# Advanced configurations with FOURCC format
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
```
Attributes:
@@ -46,17 +49,21 @@ class OpenCVCameraConfig(CameraConfig):
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOURCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
- Setting FOURCC can help achieve higher frame rates on some cameras.
"""
index_or_path: int | Path
color_mode: ColorMode = ColorMode.RGB
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
def __post_init__(self):
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
@@ -71,3 +78,8 @@ class OpenCVCameraConfig(CameraConfig):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
)
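A minimal usage sketch of the new `fourcc` option (the camera index, resolution, and format are assumptions; pick values your camera actually supports):
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig

config = OpenCVCameraConfig(index_or_path=0, fps=30, width=640, height=480, fourcc="MJPG")
camera = OpenCVCamera(config)
camera.connect()
frame = camera.read()  # RGB frame of shape (height, width, 3) by default
camera.disconnect()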
@@ -16,6 +16,8 @@ from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
@@ -62,7 +64,7 @@ class Reachy2CameraConfig(CameraConfig):
port: int = 50065
# use_depth: bool = False
def __post_init__(self):
def __post_init__(self) -> None:
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
@@ -23,13 +23,17 @@ import time
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
CameraManager,
)
from lerobot.utils.errors import DeviceNotConnectedError
@@ -73,7 +77,7 @@ class Reachy2Camera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -83,13 +87,17 @@ class Reachy2Camera(Camera):
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
return bool(
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
)
elif self.config.name == "depth":
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
return bool(
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
)
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True):
def connect(self, warmup: bool = True) -> None:
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
@@ -131,7 +139,7 @@ class Reachy2Camera(Camera):
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -152,7 +160,7 @@ class Reachy2Camera(Camera):
start_time = time.perf_counter()
frame = None
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -179,7 +187,7 @@ class Reachy2Camera(Camera):
return frame
def _read_loop(self):
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -190,6 +198,9 @@ class Reachy2Camera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -226,7 +237,7 @@ class Reachy2Camera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -269,7 +280,7 @@ class Reachy2Camera(Camera):
return frame
def disconnect(self):
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -21,11 +21,12 @@ import time
from threading import Event, Lock, Thread
from typing import Any
import cv2
import numpy as np
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
try:
import pyrealsense2 as rs
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
except Exception as e:
logging.info(f"Could not import realsense: {e}")
@@ -132,7 +133,7 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: np.ndarray | None = None
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -150,7 +151,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
def connect(self, warmup: bool = True):
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -264,7 +265,7 @@ class RealSenseCamera(Camera):
serial_number = str(found_devices[0]["serial_number"])
return serial_number
def _configure_rs_pipeline_config(self, rs_config):
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
"""Creates and configures the RealSense pipeline configuration object."""
rs.config.enable_device(rs_config, self.serial_number)
@@ -293,6 +294,9 @@ class RealSenseCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
if self.fps is None:
@@ -308,7 +312,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -336,6 +340,9 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -351,7 +358,7 @@ class RealSenseCamera(Camera):
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -376,6 +383,9 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -392,8 +402,8 @@ class RealSenseCamera(Camera):
return color_image_processed
def _postprocess_image(
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
) -> np.ndarray:
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
@@ -438,7 +448,7 @@ class RealSenseCamera(Camera):
return processed_image
def _read_loop(self):
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
@@ -449,6 +459,9 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
@@ -474,7 +487,7 @@ class RealSenseCamera(Camera):
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self):
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
@@ -486,7 +499,7 @@ class RealSenseCamera(Camera):
self.stop_event = None
# NOTE(Steven): Missing implementation for depth for now
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -529,7 +542,7 @@ class RealSenseCamera(Camera):
return frame
def disconnect(self):
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
def __post_init__(self):
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
+6 -6
@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
import cv2
import cv2 # type: ignore # TODO: add type stubs for OpenCV
if rotation == Cv2Rotation.ROTATE_90:
return cv2.ROTATE_90_CLOCKWISE
return int(cv2.ROTATE_90_CLOCKWISE)
elif rotation == Cv2Rotation.ROTATE_180:
return cv2.ROTATE_180
return int(cv2.ROTATE_180)
elif rotation == Cv2Rotation.ROTATE_270:
return cv2.ROTATE_90_COUNTERCLOCKWISE
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return cv2.CAP_ANY
return int(cv2.CAP_ANY)
+1 -1
@@ -57,7 +57,7 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
def __post_init__(self):
def __post_init__(self) -> None:
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
+14 -6
@@ -13,8 +13,8 @@
# limitations under the License.
import datetime as dt
import logging
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
@@ -22,6 +22,8 @@ from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
logger = getLogger(__name__)
@dataclass
class EvalPipelineConfig:
@@ -34,25 +36,31 @@ class EvalPipelineConfig:
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
def __post_init__(self) -> None:
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
self.policy.pretrained_path = Path(policy_path)
else:
logging.warning(
logger.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
self.job_name = (
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
)
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
if not self.output_dir:
now = dt.datetime.now()
+17 -9
@@ -16,14 +16,19 @@ import inspect
import pkgutil
import sys
from argparse import ArgumentError
from collections.abc import Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import wraps
from pathlib import Path
from pkgutil import ModuleInfo
from types import ModuleType
from typing import Any, TypeVar, cast
import draccus
from lerobot.utils.utils import has_method
F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -60,7 +65,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
"""Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -127,7 +132,7 @@ def load_plugin(plugin_path: str) -> None:
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
def iter_namespace(ns_pkg):
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try:
@@ -148,6 +153,8 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
if args is None:
return []
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
@@ -171,7 +178,8 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter]
filtered_args = args
filtered_args = [] if args is None else list(args)
for field in fields_to_filter:
if get_path_arg(field, args):
if get_type_arg(field, args):
@@ -184,7 +192,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
def wrap(config_path: Path | None = None):
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
@@ -195,9 +203,9 @@ def wrap(config_path: Path | None = None):
from the CLI '.type' arguments
"""
def wrapper_outer(fn):
def wrapper_outer(fn: F) -> F:
@wraps(fn)
def wrapper_inner(*args, **kwargs):
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype:
@@ -225,6 +233,6 @@ def wrap(config_path: Path | None = None):
response = fn(cfg, *args, **kwargs)
return response
return wrapper_inner
return cast(F, wrapper_inner)
return wrapper_outer
return cast(Callable[[F], F], wrapper_outer)
+24 -17
@@ -14,12 +14,12 @@
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import TypeVar
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
@@ -34,10 +34,11 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
"""
Base configuration class for policy models.
@@ -57,12 +58,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # cuda | cpu | mp
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
push_to_hub: bool = True
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
@@ -73,38 +74,41 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: str | None = None
pretrained_path: Path | None = None
def __post_init__(self):
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
logger.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None:
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None:
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None:
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
raise NotImplementedError
@abc.abstractmethod
@@ -154,13 +158,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs,
**policy_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -168,7 +172,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
@@ -194,6 +198,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with open(config_file) as f:
config = json.load(f)
+25 -14
@@ -16,6 +16,7 @@ import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import draccus
from huggingface_hub import hf_hub_download
@@ -63,18 +64,18 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
checkpoint_path: Path | None = field(init=False, default=None)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
self.checkpoint_path = None
def validate(self):
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
@@ -82,14 +83,22 @@ class TrainPipelineConfig(HubMixin):
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path
self.checkpoint_path = policy_path.parent
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
self.checkpoint_path = policy_dir.parent
if self.policy is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
)
if not self.job_name:
if self.env is None:
@@ -126,8 +135,8 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -139,13 +148,13 @@ class TrainPipelineConfig(HubMixin):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs,
**kwargs: Any,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -181,4 +190,6 @@ class TrainPipelineConfig(HubMixin):
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional
+1 -1
@@ -42,4 +42,4 @@ class NormalizationMode(str, Enum):
@dataclass
class PolicyFeature:
type: FeatureType
shape: tuple
shape: tuple[int, ...]
+14 -18
@@ -39,6 +39,7 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -962,28 +963,23 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
# Map file paths to episode indices to extract chunk/file indices
file_to_episodes: dict[Path, set[int]] = {}
for ep_idx in range(dataset.meta.total_episodes):
file_path = dataset.meta.get_data_file_path(ep_idx)
if file_path not in file_to_episodes:
file_to_episodes[file_path] = set()
file_to_episodes[file_path].add(ep_idx)
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
frame_idx = 0
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Get chunk_idx and file_idx from the source file's first episode
episodes_in_file = file_to_episodes[src_path]
first_ep_idx = min(episodes_in_file)
src_ep = dataset.meta.episodes[first_ep_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
relative_path = src_path.relative_to(dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
@@ -1009,7 +1005,7 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_slice
frame_idx = end_idx
# Write using the preserved chunk_idx and file_idx from source
# Write using the same chunk/file structure as source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
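As a side note, a minimal sketch of the chunk/file index parsing used above, assuming the default data/chunk-XXX/file-XXX.parquet layout (the concrete path below is illustrative):

from pathlib import Path

relative_path = Path("data/chunk-000/file-012.parquet")  # hypothetical path relative to dataset.root
chunk_dir, file_name = relative_path.parts[1], relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])                 # "chunk-000" -> 0
file_idx = int(file_name.split("-")[1].split(".")[0])    # "file-012.parquet" -> 12
assert (chunk_idx, file_idx) == (0, 12)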
+44 -16
@@ -430,9 +430,7 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(
video_key=video_key, chunk_index=0, file_index=0
)
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
@@ -686,6 +684,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = None
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self.root.mkdir(exist_ok=True, parents=True)
@@ -708,7 +707,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
@@ -830,31 +830,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes."""
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
# Get available episode indices from cached dataset
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset["episode_index"]
for ep_idx in self.hf_dataset.unique("episode_index")
}
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
return requested_episodes.issubset(available_episodes)
if not requested_episodes.issubset(available_episodes):
return False
# Check if all required video files exist
if len(self.meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
@@ -929,11 +938,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
try:
result[key] = torch.stack(self.hf_dataset[key][q_idx])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[q_idx][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -1231,6 +1255,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
global_frame_index = 0
self._current_file_start_frame = 0
# However, if the episodes already exists
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
@@ -1242,6 +1267,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._current_file_start_frame = global_frame_index
else:
# Retrieve information from the latest parquet file
latest_ep = self.latest_episode
@@ -1252,7 +1278,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
frames_in_current_file = global_frame_index - self._current_file_start_frame
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
@@ -1266,6 +1292,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._close_writer()
self._writer_closed_for_reading = False
self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
@@ -1472,6 +1499,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
obj._current_file_start_frame = None
# Initialize tracking for incremental recording
obj._lazy_loading = False
obj._recorded_frames = 0
+7
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
+18 -4
@@ -28,6 +28,7 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
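A small usage sketch of the episode filtering added to load_nested_dataset (the directory and episode indices are illustrative):

from pathlib import Path

pq_dir = Path("~/.cache/my_dataset/data").expanduser()  # hypothetical root with chunk-XXX/file-XXX.parquet files

full_ds = load_nested_dataset(pq_dir)                        # memory-mapped load of every parquet file
subset_ds = load_nested_dataset(pq_dir, episodes=[0, 3, 7])  # predicate pushdown keeps only these episodes
assert set(subset_ds.unique("episode_index")) <= {0, 3, 7}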
@@ -98,7 +98,7 @@ OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
videos/chunk-000/file_000.mp4
videos/CAMERA/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl
+8 -7
@@ -342,8 +342,8 @@ def encode_video_frames(
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
dummy_image = Image.open(input_list[0])
width, height = dummy_image.size
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
# Define video codec options
video_options = {}
@@ -373,11 +373,12 @@ def encode_video_frames(
# Loop through input frames and encode them
for input_data in input_list:
input_image = Image.open(input_data).convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
with Image.open(input_data) as input_image:
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()
+50 -2
@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:
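To illustrate the two new properties (the env type and task below are examples, not defaults pulled from a specific config):

# for an EnvConfig subclass registered as type "pusht" with task "PushT-v0":
#   cfg.package_name == "gym_pusht"            # module imported when the env is missing from the registry
#   cfg.gym_id       == "gym_pusht/PushT-v0"   # handle passed to gym.make()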
@@ -236,7 +246,14 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agent_pos": OBS_STATE,
"robot_state/eef/pos": f"{OBS_STATE}.eef_pos",
"robot_state/eef/quat": f"{OBS_STATE}.eef_quat",
"robot_state/eef/mat": f"{OBS_STATE}.eef_mat",
"robot_state/eef/axisangle": f"{OBS_STATE}.eef_axisangle",
"robot_state/gripper/qpos": f"{OBS_STATE}.gripper_qpos",
"robot_state/gripper/qvel": f"{OBS_STATE}.gripper_qvel",
"robot_state/joints/pos": f"{OBS_STATE}.joint_pos",
"robot_state/joints/vel": f"{OBS_STATE}.joint_vel",
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
@@ -251,13 +268,44 @@ class LiberoEnv(EnvConfig):
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features["robot_state/eef/pos"] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features["robot_state/eef/quat"] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features["robot_state/eef/mat"] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features["robot_state/eef/axisangle"] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features["robot_state/gripper/qpos"] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features["robot_state/gripper/qvel"] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features["robot_state/joints/pos"] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features["robot_state/joints/vel"] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
+83 -13
@@ -14,10 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -31,16 +36,59 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config.
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
cfg (EnvConfig): the config of the environment to instantiate.
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor (does nothing)
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
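A short construction-only sketch, assuming a LIBERO config instance (real configs may require task/suite arguments):

env_cfg = LiberoEnv()  # illustrative; defaults are assumed to be valid here
preprocessor, postprocessor = make_env_pre_post_processors(env_cfg)
# preprocessor carries a LiberoProcessorStep in this case; any other env type gets
# pipelines built with steps=[], i.e. identity transforms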
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
Args:
cfg (EnvConfig | str): Either an `EnvConfig` object describing the environment to build locally,
or a Hugging Face Hub repository identifier (e.g. `"username/repo"`). In the latter case,
the repo must include a Python file (usually `env.py`).
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False must be set to True to import/exec hub `env.py`.
Raises:
ValueError: if n_envs < 1
@@ -53,6 +101,21 @@ def make_env(
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
"""
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
# simplified: only support hub-provided `make_env`
if isinstance(cfg, str):
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
@@ -84,17 +147,24 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one():
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
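A hedged usage sketch of the Hub path through make_env (the repo id is hypothetical; trust_remote_code must be opted into explicitly because env.py from the repo is executed):

import gymnasium as gym

envs = make_env("username/my-env-repo@main:env.py", n_envs=2, trust_remote_code=True)

# the result is normalized to {suite_name: {task_id: vector_env}}
for suite_name, tasks in envs.items():
    for task_id, vec_env in tasks.items():
        assert isinstance(vec_env, gym.vector.VectorEnv)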
+74 -20
@@ -175,11 +175,39 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
"axisangle": spaces.Box(
low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
),
}
)
@@ -191,6 +219,7 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -212,23 +241,50 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
eef_axisangle = quat2axisangle(eef_quat) if eef_quat is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
"axisangle": eef_axisangle, # (3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -355,12 +411,10 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+152 -6
@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
import warnings
from collections.abc import Mapping, Sequence
from functools import singledispatch
@@ -22,14 +24,27 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -75,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
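A small illustration of what _convert_nested_dict does with a nested robot_state observation (the arrays below are fabricated for the example):

import numpy as np
import torch

raw = {
    "eef": {"pos": np.zeros(3), "quat": np.zeros(4)},
    "gripper": {"qpos": np.zeros(2)},
}
converted = _convert_nested_dict(raw)
# nesting is preserved, numpy arrays become torch tensors, other values pass through unchanged
assert isinstance(converted["eef"]["pos"], torch.Tensor)
assert converted["gripper"]["qpos"].shape == (2,)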
@@ -195,3 +212,132 @@ def _(envs: Sequence) -> None:
@close_envs.register
def _(env: gym.Env) -> None:
_close_single_env(env)
# helper to safely load a python file as a module
def _load_module_from_path(path: str, module_name: str | None = None):
module_name = module_name or f"hub_env_{os.path.basename(path).replace('.', '_')}"
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None:
raise ImportError(f"Could not load module spec for {module_name} from {path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
# helper to parse hub string (supports "user/repo", "user/repo@rev", optional path)
# examples:
# "user/repo" -> will look for env.py at repo root
# "user/repo@main:envs/my_env.py" -> explicit revision and path
def _parse_hub_url(hub_uri: str):
# very small parser: [repo_id][@revision][:path]
# repo_id is required (user/repo or org/repo)
revision = None
file_path = "env.py"
if "@" in hub_uri:
repo_and_rev, *rest = hub_uri.split(":", 1)
repo_id, rev = repo_and_rev.split("@", 1)
revision = rev
if rest:
file_path = rest[0]
else:
repo_id, *rest = hub_uri.split(":", 1)
if rest:
file_path = rest[0]
return repo_id, revision, file_path
def _download_hub_file(
cfg_str: str,
trust_remote_code: bool,
hub_cache_dir: str | None,
) -> tuple[str, str, str, str]:
"""
Parse `cfg_str` (hub URL), enforce `trust_remote_code`, and return
(repo_id, file_path, local_file, revision).
"""
if not trust_remote_code:
raise RuntimeError(
f"Refusing to execute remote code from the Hub for '{cfg_str}'. "
"Executing hub env modules runs arbitrary Python code from third-party repositories. "
"If you trust this repo and understand the risks, call `make_env(..., trust_remote_code=True)` "
"and prefer pinning to a specific revision: 'user/repo@<commit-hash>:env.py'."
)
repo_id, revision, file_path = _parse_hub_url(cfg_str)
try:
local_file = hf_hub_download(
repo_id=repo_id, filename=file_path, revision=revision, cache_dir=hub_cache_dir
)
except Exception as e:
# fallback to snapshot download
snapshot_dir = snapshot_download(repo_id=repo_id, revision=revision, cache_dir=hub_cache_dir)
local_file = os.path.join(snapshot_dir, file_path)
if not os.path.exists(local_file):
raise FileNotFoundError(
f"Could not find {file_path} in repository {repo_id}@{revision or 'main'}"
) from e
return repo_id, file_path, local_file, revision
def _import_hub_module(local_file: str, repo_id: str) -> Any:
"""
Import the downloaded file as a module and surface helpful import error messages.
"""
module_name = f"hub_env_{repo_id.replace('/', '_')}"
try:
module = _load_module_from_path(local_file, module_name=module_name)
except ModuleNotFoundError as e:
missing = getattr(e, "name", None) or str(e)
raise ModuleNotFoundError(
f"Hub env '{repo_id}:{os.path.basename(local_file)}' failed to import because the dependency "
f"'{missing}' is not installed locally.\n\n"
) from e
except ImportError as e:
raise ImportError(
f"Failed to load hub env module '{repo_id}:{os.path.basename(local_file)}'. Import error: {e}\n\n"
) from e
return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
"""
Ensure module exposes make_env and call it.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
)
entry_fn = module.make_env
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""
Normalize possible return types from hub `make_env` into the mapping:
{ suite_name: { task_id: vector_env } }
Accepts:
- dict (assumed already correct)
- gym.vector.VectorEnv
- gym.Env (will be wrapped into SyncVectorEnv)
"""
if isinstance(result, dict):
return result
# VectorEnv: use its spec.id if available
if isinstance(result, gym.vector.VectorEnv):
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: result}}
# Single Env: wrap into SyncVectorEnv
if isinstance(result, gym.Env):
vec = gym.vector.SyncVectorEnv([lambda: result])
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: vec}}
raise ValueError(
"Hub `make_env` must return either a mapping {suite: {task_id: vec_env}}, "
"a gym.vector.VectorEnv, or a single gym.Env."
)
+12 -8
@@ -22,18 +22,18 @@ class RobotKinematics:
self,
urdf_path: str,
target_frame_name: str = "gripper_frame_link",
joint_names: list[str] = None,
joint_names: list[str] | None = None,
):
"""
Initialize placo-based kinematics solver.
Args:
urdf_path: Path to the robot URDF file
target_frame_name: Name of the end-effector frame in the URDF
joint_names: List of joint names to use for the kinematics solver
urdf_path (str): Path to the robot URDF file
target_frame_name (str): Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
try:
import placo
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
@@ -52,7 +52,7 @@ class RobotKinematics:
# Initialize frame task for IK
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
def forward_kinematics(self, joint_pos_deg):
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
"""
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
@@ -77,8 +77,12 @@ class RobotKinematics:
return self.robot.get_T_world_frame(self.target_frame_name)
def inverse_kinematics(
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
):
self,
current_joint_pos: np.ndarray,
desired_ee_pose: np.ndarray,
position_weight: float = 1.0,
orientation_weight: float = 0.01,
) -> np.ndarray:
"""
Compute inverse kinematics using placo solver.
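A hedged usage sketch of the typed kinematics API above (URDF path and joint values are placeholders):

import numpy as np

kin = RobotKinematics(urdf_path="robot.urdf", target_frame_name="gripper_frame_link")  # hypothetical URDF

joint_pos_deg = np.zeros(6)                      # placeholder joint configuration, in degrees
ee_pose = kin.forward_kinematics(joint_pos_deg)  # 4x4 homogeneous transform of the target frame
joint_sol = kin.inverse_kinematics(
    joint_pos_deg, ee_pose, position_weight=1.0, orientation_weight=0.01
)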
+1 -1
@@ -60,7 +60,7 @@ class OperatingMode(Enum):
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4
+4
@@ -206,8 +206,12 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Position": 15,
"Goal_Velocity": 15,
"Goal_Speed": 15,
"Present_Position": 15,
"Present_Velocity": 15,
"Present_Speed": 15,
}
MODEL_ENCODING_TABLE = {
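For context, a minimal sketch of sign-magnitude encoding as implied by the bit positions above (helper names are illustrative, not the library's API):

def encode_sign_magnitude(value: int, sign_bit: int) -> int:
    # negative values set the sign bit and keep the magnitude in the lower bits
    return ((1 << sign_bit) | -value) if value < 0 else value

def decode_sign_magnitude(raw: int, sign_bit: int) -> int:
    magnitude = raw & ((1 << sign_bit) - 1)
    return -magnitude if raw & (1 << sign_bit) else magnitude

# with "Goal_Velocity" using bit 15, -100 round-trips through 0x8064
assert encode_sign_magnitude(-100, 15) == 0x8064
assert decode_sign_magnitude(0x8064, 15) == -100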
+2
@@ -14,6 +14,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -29,4 +30,5 @@ __all__ = [
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
]
+2 -2
@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
+48 -2
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -37,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -101,6 +103,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -142,6 +148,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -199,6 +207,27 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features["action"].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -293,6 +322,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
processors = make_groot_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -303,6 +340,7 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
@@ -319,6 +357,8 @@ def make_policy(
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
rename_map: Optional mapping of dataset or environment feature keys to match
expected policy feature names (e.g., `"left"` -> `"camera1"`).
Returns:
An instantiated and device-placed policy model.
@@ -360,8 +400,10 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
@@ -378,4 +420,8 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
validate_visual_features_consistency(cfg, features)
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
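A minimal usage sketch of the updated `make_policy` signature (the dataset metadata and rename keys are illustrative):

# ds_meta: a LeRobotDatasetMetadata providing feature shapes and normalization stats
cfg = make_policy_config("groot")  # registered above; defaults to the GR00T-N1.5-3B base model
policy = make_policy(
    cfg,
    ds_meta=ds_meta,
    rename_map={"left": "camera1"},  # optional: map dataset/env keys to policy feature names
)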
@@ -0,0 +1 @@
../../../../docs/source/policy_groot_README.md
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 Nvidia and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,5 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_stretch3 import Stretch3RobotConfig
from .robot_stretch3 import Stretch3Robot
from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
@@ -1,18 +1,14 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_viperx import ViperXConfig
from .viperx import ViperX
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# Compute sin/cos frequencies across the embedding dimension (half sin, half cos)
timesteps = timesteps.float() # ensure float
b, t = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
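A quick shape check for the encoder above (values are illustrative):

import torch

enc = SinusoidalPositionalEncoding(embedding_dim=32)
timesteps = torch.randint(0, 1000, (4, 16))  # (B, T)
out = enc(timesteps)
assert out.shape == (4, 16, 32)  # (B, T, w): sin and cos halves concatenated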
@@ -0,0 +1,370 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
SinusoidalPositionalEmbedding,
TimestepEmbedding,
Timesteps,
)
from torch import nn
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: torch.Tensor | None = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: int | None = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: str | None = None,
num_positional_embeddings: int | None = None,
ff_inner_dim: int | None = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.norm_type = norm_type
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
if final_dropout:
self.final_dropout = nn.Dropout(dropout)
else:
self.final_dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
encoder_hidden_states: torch.Tensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
temb: torch.LongTensor | None = None,
) -> torch.Tensor:
# 0. Self-Attention
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, temb)
else:
norm_hidden_states = self.norm1(hidden_states)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
cross_attention_dim: int | None = None,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
# Timestep encoder
self.timestep_encoder = TimestepEncoder(
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
)
all_blocks = []
for idx in range(self.config.num_layers):
use_self_attn = idx % 2 == 1 and interleave_self_attention
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
all_blocks += [
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
cross_attention_dim=curr_cross_attention_dim,
)
]
self.transformer_blocks = nn.ModuleList(all_blocks)
# Output blocks
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
):
# Encode timesteps
temb = self.timestep_encoder(timestep)
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1 and self.config.interleave_self_attention:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
temb=temb,
)
all_hidden_states.append(hidden_states)
# Output processing
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
else:
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
)
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
return_all_hidden_states: bool = False,
):
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for _idx, block in enumerate(self.transformer_blocks):
hidden_states = block(hidden_states)
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
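For orientation, a small forward-pass sketch of the DiT above (illustrative sizes; assumes `diffusers` is installed and the module above is importable):

import torch

dit = DiT(num_attention_heads=2, attention_head_dim=16, output_dim=8, num_layers=2)
b, t, s = 2, 20, 40
hidden = torch.randn(b, t, 32)       # (B, T, inner_dim), inner_dim = heads * head_dim
context = torch.randn(b, s, 32)      # with cross_attention_dim=None, context width must equal inner_dim
timestep = torch.randint(0, 1000, (b,))
out = dit(hidden_states=hidden, encoder_hidden_states=context, timestep=timestep)
assert out.shape == (b, t, 8)        # projected to output_dim by proj_out_2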
@@ -0,0 +1,406 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
else:
PretrainedConfig = object
BatchFeature = None
from lerobot.policies.groot.action_head.action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from .cross_attention_dit import DiT, SelfAttentionTransformer
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
b, t, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == b:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, t)
else:
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
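A minimal shape sketch for the action encoder above (illustrative sizes):

import torch

enc = MultiEmbodimentActionEncoder(action_dim=7, hidden_size=64, num_embodiments=4)
actions = torch.randn(2, 16, 7)           # (B, T, action_dim)
timesteps = torch.randint(0, 1000, (2,))  # one discretized flow-matching time per batch item
cat_ids = torch.tensor([0, 3])            # embodiment index per batch item
out = enc(actions, timesteps, cat_ids)
assert out.shape == (2, 16, 64)           # (B, T, hidden_size)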
@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
backbone_embedding_dim: int = field(
default=1536, metadata={"help": "Backbone embedding channel dimension."}
)
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
num_timestep_buckets: int = field(
default=1000, metadata={"help": "Number of timestep discretization buckets."}
)
num_inference_timesteps: int = field(
default=None,
metadata={"help": "Number of inference steps for noise diffusion."},
)
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
tune_diffusion_model: bool = field(
default=True, metadata={"help": "Whether to tune the diffusion model."}
)
load_pretrained_det_decode_layer_path: str = field(
default=None, metadata={"help": "Path to pretrained detection model."}
)
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
freeze_decode_layer: bool = field(default=False)
expand_batch: int = field(default=None)
use_vlln: bool = field(default=True)
vl_self_attention_cfg: dict = field(default=None)
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
class FlowmatchingActionHead(nn.Module):
config_class = FlowmatchingActionHeadConfig
supports_gradient_checkpointing = True
def __init__(
self,
config: FlowmatchingActionHeadConfig,
):
super().__init__()
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
self.model = DiT(**config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
self.vl_self_attention = (
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
for p in self.parameters():
p.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
print(f"Tune action head projector: {self.tune_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_projector and not tune_diffusion_model:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Action head trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = backbone_output["backbone_features"]
backbone_features = self.vlln(backbone_features)
backbone_features = self.vl_self_attention(backbone_features)
backbone_output["backbone_features"] = backbone_features
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
# Set frozen modules to eval
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
if self.config.expand_batch is not None:
for k, v in backbone_output.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
backbone_output[k] = expanded
for k, v in action_input.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
action_input[k] = expanded
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
device = vl_embs.device
# Get embodiment ID.
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Embed noised action trajectory.
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
# Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
vl_attn_mask = backbone_output.backbone_attention_mask
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=vl_attn_mask,
timestep=t_discretized,
return_all_hidden_states=False, # NOTE (YL): not using flare now
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# Slice out only the action portion of pred and target.
action_mask = action_input.action_mask
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = loss.sum() / action_mask.sum()
output_dict = {
"loss": loss,
}
return BatchFeature(data=output_dict)
@torch.no_grad()
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Set initial actions as the sampled noise.
batch_size = vl_embs.shape[0]
device = vl_embs.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# Run denoising steps.
for t in range(num_steps):
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# Embed noised action trajectory.
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
# Run model forward.
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
# Update actions using euler integration.
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype
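The flow-matching math used in `forward` and `get_action`, isolated into a pure-torch sketch (uniform time is a stand-in for the Beta-based sampler, and the model prediction is replaced by the true velocity):

import torch

actions = torch.randn(2, 16, 7)                   # clean action chunk (B, T, D)
noise = torch.randn_like(actions)
t = torch.rand(2)[:, None, None]                  # sampled time in (0, 1), broadcast to (B, 1, 1)
noisy_trajectory = (1 - t) * noise + t * actions  # interpolate noise -> data (training input)
velocity = actions - noise                        # regression target for the DiT output

# Inference integrates the predicted velocity with fixed Euler steps, starting from noise:
num_steps = 4
dt = 1.0 / num_steps
x = torch.randn_like(actions)
for _ in range(num_steps):
    pred_velocity = velocity                      # stand-in for the action head prediction
    x = x + dt * pred_velocity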
@@ -0,0 +1,201 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("groot")
@dataclass
class GrootConfig(PreTrainedConfig):
"""Configuration for Groot policy wrapper."""
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32
# Normalization (start with identity, adjust as needed)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
tune_llm: bool = False
# Whether to fine-tune the vision tower
tune_visual: bool = False
# Whether to fine-tune the projector
tune_projector: bool = True
# Whether to fine-tune the diffusion model
tune_diffusion_model: bool = True
# LoRA parameters (from groot_finetune_script.py)
# Rank for the LORA model. If 0, no LORA will be used.
lora_rank: int = 0
# Alpha value for the LORA model
lora_alpha: int = 16
# Dropout rate for the LORA model
lora_dropout: float = 0.1
# Whether to use the full model for LORA
lora_full_model: bool = False
# Training parameters (matching groot_finetune_script.py)
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
warmup_ratio: float = 0.05
use_bf16: bool = True
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
max_steps: int = 10000
batch_size: int = 32
dataloader_num_workers: int = 8
report_to: str = "wandb"
resume: bool = False
def __post_init__(self):
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
)
# groot_repo_path is now optional since we ported the components
# No validation needed
def validate_features(self) -> None:
"""Validate and set up input/output features for Groot."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Groot policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features["observation.state"] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features["action"] = action_feature
else:
action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
f"Either reduce action dimension or increase max_action_dim in config."
)
def get_optimizer_preset(self) -> AdamWConfig:
"""Return optimizer configuration."""
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Return scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
num_decay_steps=10000, # Adjust based on training steps
peak_lr=self.optimizer_lr,
decay_lr=self.optimizer_lr * 0.1,
)
@property
def observation_delta_indices(self) -> None:
"""Return indices for delta observations (None for Groot)."""
return None
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
return list(range(min(self.chunk_size, 16)))
@property
def reward_delta_indices(self) -> None:
"""Return indices for delta rewards (None for Groot)."""
return None
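A hedged configuration sketch showing how `validate_features` pads missing features and enforces the max dimensions (feature keys and shapes are examples):

from lerobot.configs.types import FeatureType, PolicyFeature

cfg = GrootConfig()
cfg.input_features = {
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),  # <= max_state_dim
}
cfg.output_features = {
    "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),  # <= max_action_dim
}
cfg.validate_features()  # raises if no visual feature, or if state/action exceed the max dims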
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Eagle25VLConfig(PretrainedConfig):
model_type = "eagle_2_5_vl"
is_composition = True
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
def __init__(
self,
vision_config=None,
text_config=None,
use_backbone_lora=0,
use_llm_lora=0,
pad2square=False,
select_layer=-4,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
loss_version="v1",
min_dynamic_tiles=1,
max_dynamic_tiles=6,
mlp_checkpoint=False,
initializer_range=0.02,
_attn_implementation="flash_attention_2",
_attn_implementation_autoset=False,
llm_config=None,
image_token_index=None,
use_pixel_shuffle=True,
mlp_connector_layers=2,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {"model_type": "siglip_vision_model"}
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
if text_config is None:
text_config = {"architectures": ["Qwen2ForCausalLM"]}
logger.info(
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
)
if vision_config["model_type"] == "siglip_vision_model":
self.vision_config = SiglipVisionConfig(**vision_config)
else:
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
if text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = LlamaConfig(**text_config)
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
self.text_config = Qwen2Config(**text_config)
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
self.text_config = Qwen3Config(**text_config)
else:
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.mlp_checkpoint = mlp_checkpoint
self.pad2square = pad2square
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.loss_version = loss_version
self.initializer_range = initializer_range
self.min_dynamic_tiles = min_dynamic_tiles
self.max_dynamic_tiles = max_dynamic_tiles
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self._attn_implementation = _attn_implementation
self._attn_implementation_autoset = _attn_implementation_autoset
self.image_token_index = image_token_index
self.use_pixel_shuffle = use_pixel_shuffle
self.mlp_connector_layers = mlp_connector_layers
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
output["use_backbone_lora"] = self.use_backbone_lora
output["use_llm_lora"] = self.use_llm_lora
output["pad2square"] = self.pad2square
output["select_layer"] = self.select_layer
output["force_image_size"] = self.force_image_size
output["downsample_ratio"] = self.downsample_ratio
output["template"] = self.template
output["dynamic_image_size"] = self.dynamic_image_size
output["use_thumbnail"] = self.use_thumbnail
output["min_dynamic_tiles"] = self.min_dynamic_tiles
output["max_dynamic_tiles"] = self.max_dynamic_tiles
output["tie_word_embeddings"] = self.tie_word_embeddings
output["_attn_implementation"] = self._attn_implementation
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
output["use_pixel_shuffle"] = self.use_pixel_shuffle
output["mlp_connector_layers"] = self.mlp_connector_layers
return output
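A minimal construction sketch (the defaults fall back to `SiglipVisionConfig` and `Qwen2Config`):

cfg = Eagle25VLConfig()  # logs that default vision/text sub-configs are being used
as_dict = cfg.to_dict()
assert as_dict["model_type"] == "eagle_2_5_vl"
assert as_dict["vision_config"]["model_type"] == "siglip_vision_model"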
@@ -0,0 +1,504 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
validate_kwargs,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_v2_available,
)
from transformers.video_utils import VideoInput
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F # noqa: N812
from transformers.image_utils import pil_torch_interpolation_mapping
else:
from torchvision.transforms import functional as F # noqa: N812
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
"""Crop the given numpy array.
Args:
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
left (int): The left coordinate of the crop box.
top (int): The top coordinate of the crop box.
right (int): The right coordinate of the crop box.
bottom (int): The bottom coordinate of the crop box.
Returns:
torch.Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
if img.ndim not in [2, 3]:
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
img_height = img.shape[1]
img_width = img.shape[2]
if top < 0 or left < 0 or bottom > img_height or right > img_width:
raise ValueError("Crop coordinates out of bounds")
if top >= bottom or left >= right:
raise ValueError("Invalid crop coordinates")
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
pad_during_tiling: bool | None
do_pad: bool | None
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
max_dynamic_tiles = 12
min_dynamic_tiles = 1
use_thumbnail = True
pad_during_tiling = False
valid_kwargs = Eagle25VLFastImageProcessorKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was deprecated in transformers; remove!
"""
max_dynamic_tiles (`int`, *optional*):
The maximum number of dynamic tiles to use for processing high resolution images.
min_dynamic_tiles (`int`, *optional*):
The minimum number of dynamic tiles to use for processing high resolution images.
use_thumbnail (`bool`, *optional*):
Whether to use a thumbnail for processing high resolution images.
pad_during_tiling (`bool`, *optional*):
Whether to pad the image during tiling.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
# NOTE(YL): we will overload the preprocess method to add the image_flags
# def preprocess(
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
# ) -> BatchFeature:
# return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
expected_ndims (`int`, *optional*, defaults to 3):
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
"""
The previous version focused mainly on the aspect ratio. Here we also take the area ratio into account.
"""
best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
"""
new area > 60% of original image area is enough.
"""
factor_based_on_area_n_ratio = min(
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: "torch.Tensor",
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
# calculate the target width and height
target_width = tile_size * target_aspect_ratio[0]
target_height = tile_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if pad_during_tiling:
resized_image = self._resize_for_patching(
image,
(target_height, target_width),
interpolation=interpolation,
input_data_format=ChannelDimension.FIRST,
)
padded_image = self._pad_for_patching(
resized_image,
(target_height, target_width),
input_data_format=ChannelDimension.FIRST,
)
image_used_to_split = padded_image
else:
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
processed_tiles = []
for i in range(blocks):
box = (
(i % (target_width // tile_size)) * tile_size,
(i // (target_width // tile_size)) * tile_size,
((i % (target_width // tile_size)) + 1) * tile_size,
((i // (target_width // tile_size)) + 1) * tile_size,
)
# split the image
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
processed_tiles.append(split_img)
assert len(processed_tiles) == blocks
if use_thumbnail and len(processed_tiles) != 1:
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
processed_tiles.append(thumbnail_img)
return processed_tiles
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch with the same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
A list of pixel values for each image, of shape (`batch_size`, `num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
do_pad: bool,
return_tensors: str | TensorType | None,
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
tile_size = crop_size.height
elif size and size.height:
tile_size = size.height
else:
tile_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
min_num=min_dynamic_tiles,
max_num=max_dynamic_tiles,
size=size_tuple,
tile_size=tile_size,
use_thumbnail=use_thumbnail,
interpolation=interpolation,
pad_during_tiling=pad_during_tiling,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
# Added for transformers >=4.53.0 compatibility
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches,
disable_grouping=disable_grouping,
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(
processed_image_patches_grouped, grouped_image_patches_index
)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes},
tensor_type=return_tensors,
)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
if images is not None:
images = self._prepare_image_like_inputs(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
if videos is not None:
videos = self._prepare_image_like_inputs(
images=videos,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
# Added for transformers >=4.53.0 compatibility
resample = kwargs.pop("resample", self.resample)
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, PILImageResampling | int)
else resample
)
# Filter kwargs to only include those accepted by _preprocess
valid_preprocess_kwargs = {
"do_resize",
"size",
"max_dynamic_tiles",
"min_dynamic_tiles",
"use_thumbnail",
"pad_during_tiling",
"interpolation",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"pad_size",
"disable_grouping",
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
if images is not None:
return self._preprocess(images, **filtered_kwargs)
elif videos is not None:
return self._preprocess(videos, **filtered_kwargs)
__all__ = ["Eagle25VLImageProcessorFast"]
@@ -0,0 +1,395 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import inspect
import torch
import torch.utils.checkpoint as cp
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.utils import add_start_docstrings, logging
from .configuration_eagle2_5_vl import Eagle25VLConfig
logger = logging.get_logger(__name__)
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Eagle25VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle25VLPreTrainedModel(PreTrainedModel):
config_class = Eagle25VLConfig
base_model_prefix = "model"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = [
"Qwen2DecoderLayer",
"LlamaDecoderLayer",
"Siglip2EncoderLayer",
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear | nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
config_class = Eagle25VLConfig
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
if config.use_pixel_shuffle:
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
else:
self.num_image_token = int((image_size // patch_size) ** 2)
self.select_layer = config.select_layer
self.downsample_ratio = config.downsample_ratio
self.loss_version = config.loss_version
self.mlp_checkpoint = config.mlp_checkpoint
self.use_pixel_shuffle = config.use_pixel_shuffle
self.mlp_connector_layers = config.mlp_connector_layers
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
if vision_model is not None:
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
config.vision_config._attn_implementation = "flash_attention_2"
self.vision_model = SiglipVisionModel(config.vision_config)
else:
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
if language_model is not None:
self.language_model = language_model
else:
if config.text_config.architectures[0] == "LlamaForCausalLM":
self.language_model = LlamaForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
assert config.text_config._attn_implementation == "flash_attention_2", (
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
)
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
else:
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
if config.mlp_connector_layers == 2:
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size, llm_hidden_size),
)
else:
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
self.image_token_index = config.image_token_index
self.neftune_alpha = None
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
self.use_llm_lora = config.use_llm_lora
if config.use_llm_lora:
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
self.check_forward_kwargs()
def check_forward_kwargs(self):
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
# has special handling for functions with **kwargs parameters that would affect
# how our model is processed during training and inference.
forward_params = inspect.signature(self.forward).parameters
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.out_proj",
"mlp.fc1",
"mlp.fc2",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type="CAUSAL_LM",
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
self.use_llm_lora = True
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
image_flags: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
num_tiles_list: list[torch.Tensor] | None = None,
) -> tuple | CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
if image_flags is not None:
image_flags = image_flags.view(-1)
vit_embeds = vit_embeds[image_flags == 1]
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.image_token_index
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, c)
print(
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
f"vit_embeds.shape={vit_embeds.shape}"
)
n_token = selected.sum()
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(b, n, c)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
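# Shape sketch (scale_factor=0.5): a (1, 32, 32, 1024) grid of ViT patch tokens is folded into
# (1, 16, 16, 4096) -- 4x fewer tokens, each carrying 4x more channels, e.g.:
#   x = torch.randn(1, 32, 32, 1024)
#   assert self.pixel_shuffle(x, scale_factor=0.5).shape == (1, 16, 16, 4096)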
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
)
if hasattr(vit_embeds, "last_hidden_state"):
vit_embeds = vit_embeds.last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
if self.use_pixel_shuffle:
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
if self.mlp_checkpoint and vit_embeds.requires_grad:
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
else:
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor | None = None,
input_ids: torch.FloatTensor | None = None,
attention_mask: torch.LongTensor | None = None,
visual_features: torch.FloatTensor | None = None,
generation_config: GenerationConfig | None = None,
output_hidden_states: bool | None = None,
image_sizes: list[tuple[int, int]] | None = None,
**generate_kwargs,
) -> torch.LongTensor:
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.config.image_token_index
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
input_embeds = input_embeds.reshape(b, n, c)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
if "use_cache" not in generate_kwargs:
generate_kwargs["use_cache"] = True
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
**generate_kwargs,
)
return outputs
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
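# Generation sketch (illustrative; the checkpoint path is a placeholder and `inputs` is assumed
# to come from Eagle25VLProcessor, defined in the processing file below):
#   model = Eagle25VLForConditionalGeneration.from_pretrained("/path/to/eagle2_hg_model")
#   out_ids = model.generate(
#       pixel_values=inputs["pixel_values"],
#       input_ids=inputs["input_ids"],
#       attention_mask=inputs["attention_mask"],
#       max_new_tokens=64,
#   )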
@@ -0,0 +1,518 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle25VL.
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""
import base64
import os
import re
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == "RGBA":
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
image = ele["image"] if "image" in ele else ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True, timeout=10)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = to_rgb(image_obj)
if "scale_factor" in ele:
scale_factor = ele["scale_factor"]
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
return image
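# fetch_image accepts several input forms, e.g. (paths and URLs below are placeholders):
#   fetch_image({"image": "https://example.com/cat.png"})              # http(s) URL
#   fetch_image({"image": "file:///tmp/cat.png"})                      # file:// path
#   fetch_image({"image": "/tmp/cat.png"})                             # plain local path
#   fetch_image({"image": pil_img, "scale_factor": 2})                 # PIL image, upscaled 2x
# base64 data URIs of the form "data:image/...;base64,<payload>" are also handled.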
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"videos_kwargs": {"max_dynamic_tiles": 1},
}
class Eagle25VLProcessor(ProcessorMixin):
r"""
Constructs an Eagle25VL processor which wraps an Eagle25VL video processor, an Eagle25VL image processor and an Eagle25VL tokenizer into a single processor.
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
Args:
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
num_image_tokens (`int`, *optional*):
Number of image tokens for one image that will be returned by the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"num_image_tokens",
"vision_feature_select_strategy",
"image_token",
"video_token",
"images_kwargs",
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<IMG_CONTEXT>", # nosec: B107
video_token="<IMG_CONTEXT>", # nosec: B107
tokens_per_tile=256,
image_placeholder="image",
video_placeholder="video",
image_start_token="<img>",
image_end_token="</img>",
**kwargs,
):
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.video_token_id = (
tokenizer.video_token_id
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.image_placeholder = image_placeholder
self.video_placeholder = video_placeholder
self.tokens_per_tile = tokens_per_tile
self.image_start_token = image_start_token
self.image_end_token = image_end_token
if "auto_map" in kwargs:
self.auto_map = kwargs["auto_map"]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def replace_media_placeholder(
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
):
num_of_images_in_this_sample = 0
num_of_videos_in_this_sample = 0
# Regular expression pattern to match formats like <image-1> or <video-2>
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
unified_frame_list = []
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
# )
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
# )
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
# "use_thumbnail", self.image_processor.use_thumbnail
# )
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
)
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
)
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
"use_thumbnail", self.image_processor.use_thumbnail
)
tile_size = self.image_processor.size.get("height", 448)
# Function to replace tags in a single text
def replace_in_text(text):
# repl callback function for each match replacement operation
def repl(match):
nonlocal unified_frame_list
nonlocal num_of_images_in_this_sample
nonlocal num_of_videos_in_this_sample
media_type = match.group(1) # 'image' or 'video'
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
# Select the corresponding path based on media type
idx_mapper = {
0: "first",
1: "second",
2: "third",
3: "fourth",
4: "fifth",
5: "sixth",
6: "seventh",
7: "eighth",
8: "ninth",
9: "tenth",
}
if media_type == "image":
image_inputs = self.image_processor(
images=[image_list[idx_in_list]],
videos=None,
**output_kwargs["images_kwargs"],
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
num_of_images_in_this_sample += 1
elif media_type == "video":
video_inputs = self.image_processor(
images=None,
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
frame_timestamps = timestamps_list[idx_in_list]
else:
frame_timestamps = None
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
num_of_tiles_each_frame = [
self.get_number_tiles_based_on_image_size(
image_size,
video_min_dynamic_tiles,
video_max_dynamic_tiles,
video_use_thumbnail,
tile_size,
)
for image_size in image_sizes
]
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
)
if frame_timestamps is not None:
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
)
special_placeholder = [
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
else:
special_placeholder = [
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
if sampled_fps is not None:
special_placeholder = (
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
+ "".join(special_placeholder)
)
else:
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
special_placeholder
)
unified_frame_list.append(video_inputs)
num_of_videos_in_this_sample += 1
else:
raise ValueError(f"Unknown media type: {media_type}")
return special_placeholder
return pattern.sub(repl, text)
text = replace_in_text(text)
if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
return (
text,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
)
def __call__(
self,
images: ImageInput = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Eagle25VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
    text_list = [text]
elif isinstance(text, list) and all(isinstance(t, str) for t in text):
    text_list = text
else:
    raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is None:
images = []
if videos is None:
videos = []
pixel_values_list = []
image_sizes_list = []
new_sample_list = []
image_start_idx = 0
video_start_idx = 0
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
for sample in text_list:
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
(
sample,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
) = self.replace_media_placeholder(
sample,
images[image_start_idx:],
videos[video_start_idx:],
timestamps_list,
fps_list,
**output_kwargs,
)
new_sample_list.append(sample)
if pixel_values is not None:
pixel_values_list.append(pixel_values)
image_sizes_list.append(image_sizes)
image_start_idx += num_of_images_in_this_sample
video_start_idx += num_of_videos_in_this_sample
if len(pixel_values_list) > 0:
image_inputs = {
"pixel_values": torch.cat(pixel_values_list),
"image_sizes": torch.cat(image_sizes_list),
}
else:
image_inputs = {}
video_inputs = {}
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
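# Usage sketch (variable names are placeholders): media placeholders such as <image-1> or
# <video-2> refer to the 1-based position of the item in `images` / `videos`, e.g.:
#   out = processor(
#       images=[img_a, img_b],
#       text="Compare <image-1> with <image-2>. Which looks brighter?",
#   )
#   out["input_ids"], out["pixel_values"], out["image_sizes"]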
def get_number_tiles_based_on_image_size(
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
) -> int:
"""
Get the number of tiles based on the image size.
"""
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and tiles_num > 1:
tiles_num += 1
return tiles_num
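# Rough example (the exact result depends on find_closest_aspect_ratio): for a 448x896 (H, W)
# frame with tile_size=448, min_num=1, max_num=12 and use_thumbnail=True, the closest grid is
# typically (2, 1), i.e. 2 tiles plus 1 thumbnail tile -> 3 tiles:
#   n_tiles = processor.get_number_tiles_based_on_image_size((448, 896), 1, 12, True, 448)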
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# override to save video-config in a separate config file
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
outputs = super().save_pretrained(save_directory, **kwargs)
return outputs
# override to load video-config from a separate config file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
return processor
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def process_vision_info(
self,
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
vision_infos = self.extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
video_timestamps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return (
image_inputs,
video_inputs,
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
)
return image_inputs, video_inputs
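# Conversation layout sketch (Qwen2.5-VL style; the path is a placeholder):
#   conversations = [{
#       "role": "user",
#       "content": [
#           {"type": "image", "image": "file:///tmp/scene.png"},
#           {"type": "text", "text": "What is on the table?"},
#       ],
#   }]
#   image_inputs, video_inputs = processor.process_vision_info(conversations)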
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
__all__ = ["Eagle25VLProcessor"]
@@ -0,0 +1,376 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.policies.groot.action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
from lerobot.utils.constants import HF_LEROBOT_HOME
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: False)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@dataclass
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if "action" in inputs:
action = inputs["action"]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
dtype_ok = video.dtype == np.uint8
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{video.dtype=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{video.shape=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model
@@ -0,0 +1,198 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration are handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import os
from collections import deque
import torch
from torch import Tensor
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
name = "groot"
config_class = GrootConfig
def __init__(self, config: GrootConfig):
"""Initialize Groot policy wrapper."""
super().__init__(config)
config.validate_features()
self.config = config
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T model using Isaac-GR00T API.
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
return model
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = next(self.parameters()).device
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.forward(groot_inputs)
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
loss_dict = {"loss": loss.item()}
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
"""
self.eval()
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
original_action_dim = self.config.output_features["action"].shape[0]
actions = actions[:, :, :original_action_dim]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select single action from action queue."""
self.eval()
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
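# Inference sketch (assumes `batch` was produced by the Groot preprocessor pipeline):
#   policy.reset()
#   action = policy.select_action(batch)   # first call runs the model and fills the queue
#   action = policy.select_action(batch)   # later calls pop already-predicted actions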
# -------------------------
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
@@ -0,0 +1,664 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, ProcessorMixin
else:
AutoProcessor = None
ProcessorMixin = object
from lerobot.configs.types import (
FeatureType,
NormalizationMode,
PolicyFeature,
)
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
HF_LEROBOT_HOME,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
# Defaults for Eagle processor locations
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessor and postprocessor for Groot policy.
This creates a processing pipeline that transforms LeRobot data format into
the format expected by Isaac-GR00T models:
Preprocessing steps:
1. Optional key renaming (dataset-specific key mapping)
2. Add batch dimension to unbatched data
3. Pack video/state/action/language/embodiment and apply optional min-max normalization before padding
4. Encode video+language with Eagle VLM into intermediate eagle_content
5. Collate eagle_content into batched eagle_* tensors
6. Move tensors to device (GPU)
NOTE: We optionally apply min-max normalization to STATE and ACTION using
dataset-provided statistics prior to padding, mapping values to [-1, 1].
This mirrors SO100-style preprocessing and keeps scales consistent with GR00T.
Args:
config: Groot configuration containing data_config, embodiment_tag, etc.
dataset_stats: Optional per-key min/max statistics for normalization before padding.
Returns:
Tuple of (preprocessor, postprocessor) pipelines
"""
# Get horizon/dimension parameters from config
# These should match the config used for the pretrained model
# Default values match most GR00T configs (state_horizon=1, action_horizon=16)
state_horizon = 1
# CRITICAL: Pretrained GR00T models use action_horizon=16 max!
# The model architecture hardcodes this limit
action_horizon = min(config.chunk_size, 16)
max_state_dim = config.max_state_dim
max_action_dim = config.max_action_dim
# Pass raw dataset_stats; normalization will occur inside pack step before padding
padded_stats = dataset_stats or {}
# Define feature specs for optional normalization steps
_features: dict[str, PolicyFeature] = {
# Observation features (only add those we may normalize)
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
# Action feature
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
}
# Normalize STATE and ACTION with min_max (SO100-like default)
_norm_map = {
FeatureType.ACTION: NormalizationMode.MIN_MAX,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Determine env action dimension from config (simple, object-like PolicyFeature)
try:
env_action_dim = int(config.output_features["action"].shape[0])
except Exception:
env_action_dim = 0
input_steps: list[ProcessorStep] = [
# 1. Rename keys if needed (e.g., dataset-specific camera names)
# Leave empty for now - add mappings if your dataset uses different key names
RenameObservationsProcessorStep(rename_map={}),
# 2. Add batch dimension for single samples
AddBatchDimensionProcessorStep(),
# 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding
GrootPackInputsStep(
state_horizon=state_horizon,
action_horizon=action_horizon,
max_state_dim=max_state_dim,
max_action_dim=max_action_dim,
language_key="task",
formalize_language=False,
embodiment_tag=config.embodiment_tag,
normalize_min_max=True,
stats=padded_stats,
),
# 4. Eagle encode (creates eagle_content)
GrootEagleEncodeStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 5. Collate eagle_content -> eagle_* tensors
GrootEagleCollateStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 6. Move to device
DeviceProcessorStep(device=config.device),
]
# Postprocessing: slice to env action dim and unnormalize to env scale, then move to CPU
output_steps: list[ProcessorStep] = [
GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
normalize_min_max=True,
),
# Finally, move to CPU for env interaction
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
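# Illustrative usage sketch (not part of this module): wiring the two pipelines around a
# GR00T policy. `cfg`, `stats` and `raw_obs` are placeholders for your own GrootConfig,
# dataset statistics and a LeRobot-format observation dict; the model call itself is elided.
#
#   preprocessor, postprocessor = make_groot_pre_post_processors(cfg, dataset_stats=stats)
#   batch = preprocessor(raw_obs)             # packed, normalized, eagle_* tensors on cfg.device
#   model_action = ...                        # GR00T output, (B, T, D_model) or (B, D_model)
#   env_action = postprocessor(model_action)  # sliced to env dims, unnormalized, on CPU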
# GR00T specific processor steps
def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray:
# img_t: (B, C, H, W) float in [0,1] or uint8
if img_t.dtype.is_floating_point:
img_t = (img_t.clamp(0, 1) * 255.0).to(torch.uint8)
return rearrange(img_t.cpu().numpy(), "b c h w -> b h w c")
def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin:
# Validate that the cache directory is ready. If not, instruct the user.
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
required = [
cache_dir / "processor_config.json",
cache_dir / "preprocessor_config.json",
cache_dir / "image_processing_eagle2_5_vl_fast.py",
]
if not all(p.exists() for p in required):
raise FileNotFoundError(
f"[GROOT] Eagle processor cache at '{cache_dir}' is not populated. "
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc.tokenizer.padding_side = "left"
return proc
@dataclass
@ProcessorStepRegistry.register(name="groot_pack_inputs_v3")
class GrootPackInputsStep(ProcessorStep):
state_horizon: int = 1
action_horizon: int = 16
max_state_dim: int = 64
max_action_dim: int = 32
language_key: str = "task"
formalize_language: bool = False
embodiment_tag: str = "new_embodiment"
embodiment_mapping: dict[str, int] = field(
default_factory=lambda: {
"new_embodiment": 31, # Match original GR00T EMBODIMENT_TAG_MAPPING
"oxe_droid": 17,
"agibot_genie1": 26,
"gr1": 24,
"so100": 2,
"unitree_g1": 3,
}
)
# Min-max normalization (SO100-like) applied BEFORE padding
normalize_min_max: bool = True
stats: dict[str, dict[str, Any]] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor:
t = torch.as_tensor(vec)
t = t.flatten().to(
dtype=torch.float32,
device=next(
(v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu")
),
)
d = int(t.shape[-1]) if t.numel() > 0 else 0
if d == target_dim:
return t
if d < target_dim:
pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device)
return torch.cat([t, pad], dim=0)
return t[:target_dim]
def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor:
if not self.normalize_min_max:
return x
if self.stats is None or key not in self.stats:
return x
stats_k = self.stats[key]
last_dim = x.shape[-1]
min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0)
max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0)
denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
mapped = 2 * (x - min_v) / safe_denom - 1
return torch.where(mask, mapped, torch.zeros_like(mapped))
# 1) Video (B, T=1, V, H, W, C) uint8
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
if not img_keys and "observation.image" in obs:
img_keys = ["observation.image"]
if img_keys:
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
video = np.stack(cams, axis=1) # (B, V, H, W, C)
video = np.expand_dims(video, axis=1) # (B, 1, V, H, W, C)
# GR00T validates that video.shape[3] == 3 (channels), so reorder to (B, T, V, C, H, W)
video = np.transpose(video, (0, 1, 2, 5, 3, 4)) # (B, 1, V, C, H, W)
obs["video"] = video
# Drop raw images to avoid confusion downstream
for k in img_keys:
obs.pop(k, None)
# 2) Language (string)
lang = comp.get(self.language_key)
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else None
if not lang:
lang = "Perform the task."
if self.formalize_language:
lang = (lang or "").lower()
lang = "".join(ch for ch in lang if ch.isalnum() or ch.isspace())
comp["language"] = lang
# 3) State/state_mask -> (B, 1, max_state_dim)
if "observation.state" in obs:
state = obs["observation.state"] # (B, D)
if state.dim() != 2:
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
bsz, d = state.shape
# Normalize BEFORE padding
if self.normalize_min_max:
state = _min_max_norm(state, "observation.state")
state = state.unsqueeze(1) # (B, 1, D)
if d > self.max_state_dim:
state = state[:, :, : self.max_state_dim]
d = self.max_state_dim
elif d < self.max_state_dim:
pad = torch.zeros(bsz, 1, self.max_state_dim - d, dtype=state.dtype, device=state.device)
state = torch.cat([state, pad], dim=2)
state_mask = torch.zeros(bsz, 1, self.max_state_dim, dtype=torch.bool, device=state.device)
state_mask[:, :, :d] = True
obs["state"] = state
obs["state_mask"] = state_mask
# 4) Action/action_mask -> (B, action_horizon, max_action_dim)
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
# Normalize BEFORE temporal expansion/padding
if self.normalize_min_max:
if action.dim() == 2:
action = _min_max_norm(action, "action")
elif action.dim() == 3:
b, t, d = action.shape
flat = action.reshape(b * t, d)
flat = _min_max_norm(flat, "action")
action = flat.view(b, t, d)
if action.dim() == 2:
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
elif action.dim() == 3:
b, t, d = action.shape
if t < self.action_horizon:
last = action[:, -1:, :]
pad = last.repeat(1, self.action_horizon - t, 1)
action = torch.cat([action, pad], dim=1)
elif t > self.action_horizon:
action = action[:, : self.action_horizon, :]
else:
raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}")
b, t, d = action.shape
if d > self.max_action_dim:
action = action[:, :, : self.max_action_dim]
d = self.max_action_dim
elif d < self.max_action_dim:
pad = torch.zeros(b, t, self.max_action_dim - d, dtype=action.dtype, device=action.device)
action = torch.cat([action, pad], dim=2)
action_mask = torch.zeros(b, t, self.max_action_dim, dtype=torch.bool, device=action.device)
action_mask[:, :, :d] = True
transition[TransitionKey.ACTION] = action
comp["action_mask"] = action_mask
# 5) Embodiment id as LongTensor (B,)
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
# Infer batch size/device from any tensor in obs or action
bsz = None
device = torch.device("cpu")
for v in list(obs.values()) + [transition.get(TransitionKey.ACTION)]:
if isinstance(v, torch.Tensor):
bsz = v.shape[0]
device = v.device
break
if bsz is None and "video" in obs and isinstance(obs["video"], np.ndarray):
bsz = obs["video"].shape[0]
if bsz is None:
bsz = 1
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.long, device=device)
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (we keep it simple)
def transform_features(self, features):
return features
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
Excludes 'stats' since they are saved separately via state_dict().
"""
return {
"state_horizon": self.state_horizon,
"action_horizon": self.action_horizon,
"max_state_dim": self.max_state_dim,
"max_action_dim": self.max_action_dim,
"language_key": self.language_key,
"formalize_language": self.formalize_language,
"embodiment_tag": self.embodiment_tag,
"embodiment_mapping": self.embodiment_mapping,
"normalize_min_max": self.normalize_min_max,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns normalization statistics as a flat state dictionary.
This enables saving stats to safetensors files, similar to normalizer_processor.
"""
if not self.stats:
return {}
flat: dict[str, torch.Tensor] = {}
for key, sub in self.stats.items():
for stat_name, value in sub.items():
tensor = torch.as_tensor(value).cpu()
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""
Loads normalization statistics from a flat state dictionary.
This enables loading stats from safetensors files during from_pretrained.
"""
if not state:
return
reconstructed: dict[str, dict[str, Any]] = {}
for flat_key, tensor in state.items():
if "." in flat_key:
key, stat_name = flat_key.rsplit(".", 1)
if key not in reconstructed:
reconstructed[key] = {}
reconstructed[key][stat_name] = tensor
if reconstructed:
self.stats = reconstructed
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_encode_v3")
class GrootEagleEncodeStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
if "video" not in obs:
return transition
video = obs["video"] # (B, T, V, H, W, C) uint8
lang = comp.get("language", "Perform the task.")
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else "Perform the task."
bsz = video.shape[0]
eagle_contents: list[dict[str, Any]] = []
for b in range(bsz):
vt = video[b]  # (T, V, C, H, W) after reorder
# The reordered layout is channels-first; fall back to channels-last only when the last
# axis actually holds the 3 color channels (a 5D check alone cannot tell the layouts apart).
if vt.shape[-1] == 3:
# Fallback: assume (T, V, H, W, C)
t, v, h, w, c = vt.shape
flat = rearrange(vt, "t v h w c -> (t v) h w c")
else:
t, v, c, h, w = vt.shape
flat = rearrange(vt, "t v c h w -> (t v) h w c")
images = [Image.fromarray(flat[i]) for i in range(t * v)]
# Format language as a string-list representation to match the original GR00T
lang_formatted = str([lang])
text_content = [{"type": "text", "text": lang_formatted}]
image_content = [{"type": "image", "image": img} for img in images]
conv = [{"role": "user", "content": image_content + text_content}]
text_list = [self.proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)]
img_inputs, vid_inputs = self.proc.process_vision_info(conv)
eagle_contents.append(
{
"text_list": text_list,
"image_inputs": img_inputs,
"video_inputs": vid_inputs,
}
)
comp["eagle_content"] = eagle_contents
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (no schema change here)
def transform_features(self, features):
return features
# Original GR00T-style collate: converts eagle_content -> eagle_* tensors
def collate(features: list[dict[str, Any]], eagle_processor: ProcessorMixin) -> dict[str, Any]:
batch: dict[str, Any] = {}
keys = features[0].keys()
for key in keys:
values = [elem[key] for elem in features]
if key == "eagle_content":
text_list: list[str] = []
image_inputs: list[Any] = []
for v in values:
curr_text_list = v["text_list"]
curr_image_inputs = v["image_inputs"]
text_list += curr_text_list
image_inputs += curr_image_inputs
eagle_inputs = eagle_processor(
text=text_list,
images=image_inputs,
images_kwargs={"min_dynamic_tiles": 1, "max_dynamic_tiles": 1, "use_thumbnail": False},
return_tensors="pt",
padding=True,
)
for k, v in eagle_inputs.items():
k = "eagle_" + k
batch[k] = v
elif key in ("pixel_values", "image_grid_thw", "attention_mask", "input_ids"):
# Concatenate along the existing batch dimension.
batch[key] = torch.cat(values)
else:
# state, state_mask, action and action_mask.
# Stack to form the batch dimension.
batch[key] = torch.from_numpy(np.stack(values))
return batch
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_collate_v3")
class GrootEagleCollateStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
contents = comp.get("eagle_content")
if not contents:
return transition
# Build features list as original API expects: one dict per batch item
features = [{"eagle_content": content} for content in contents]
batched = collate(features, self.proc)
# Inject eagle_* tensors and remove the temporary content and raw video to free memory
for k, v in batched.items():
comp[k] = v
comp.pop("eagle_content", None)
obs.pop("video", None)  # the video has been fully encoded into eagle_* tensors, so the raw frames are no longer needed
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
def transform_features(self, features):
return features
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
# Apply inverse of min-max normalization if it was used in preprocessor
normalize_min_max: bool = True
stats: dict[str, dict[str, Any]] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
action = transition.get(TransitionKey.ACTION)
if not isinstance(action, torch.Tensor):
return transition
# Select last timestep and slice to env dimension
if action.dim() == 3:
action = action[:, -1, :]
# Now action is (B, D_model)
if self.env_action_dim and action.shape[-1] >= self.env_action_dim:
action = action[..., : self.env_action_dim]
# Inverse min-max normalization mirroring _min_max_norm:
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
if self.normalize_min_max and self.stats is not None:
stats_k = self.stats.get("action", {})
d = action.shape[-1]
min_v = torch.as_tensor(
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
)
max_v = torch.as_tensor(
stats_k.get("max", torch.ones(d)), dtype=action.dtype, device=action.device
)
if min_v.numel() != d:
min_v = torch.nn.functional.pad(min_v.flatten()[:d], (0, max(0, d - min_v.numel())))
min_v = min_v.to(action.device, dtype=action.dtype)
if max_v.numel() != d:
max_v = torch.nn.functional.pad(max_v.flatten()[:d], (0, max(0, d - max_v.numel())))
max_v = max_v.to(action.device, dtype=action.dtype)
denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
inv = (action + 1.0) * 0.5 * safe_denom + min_v
action = torch.where(mask, inv, min_v)
transition[TransitionKey.ACTION] = action
return transition
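# Worked numeric check of the min-max mapping above (illustrative): with min=0 and max=10,
# the preprocessor maps x=7.5 to y = 2 * (7.5 - 0) / 10 - 1 = 0.5, and this step inverts it
# back as x = (0.5 + 1) / 2 * 10 + 0 = 7.5. Dimensions with max == min are mapped to 0 on
# the way in and restored to min on the way out.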
def transform_features(self, features):
return features
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
Excludes 'stats' since they are saved separately via state_dict().
"""
return {
"env_action_dim": self.env_action_dim,
"normalize_min_max": self.normalize_min_max,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns normalization statistics as a flat state dictionary.
This enables saving stats to safetensors files, similar to normalizer_processor.
"""
if not self.stats:
return {}
flat: dict[str, torch.Tensor] = {}
for key, sub in self.stats.items():
for stat_name, value in sub.items():
tensor = torch.as_tensor(value).cpu()
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""
Loads normalization statistics from a flat state dictionary.
This enables loading stats from safetensors files during from_pretrained.
"""
if not state:
return
reconstructed: dict[str, dict[str, Any]] = {}
for flat_key, tensor in state.items():
if "." in flat_key:
key, stat_name = flat_key.rsplit(".", 1)
if key not in reconstructed:
reconstructed[key] = {}
reconstructed[key][stat_name] = tensor
if reconstructed:
self.stats = reconstructed
+47
@@ -0,0 +1,47 @@
from pathlib import Path
from shutil import copytree
from huggingface_hub import hf_hub_download
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)
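# Illustrative usage sketch: priming the Eagle processor cache before building the GR00T
# processors. The vendor directory below is a placeholder; in practice it points at the
# Eagle files vendored with the GR00T policy implementation, and the cache location follows
# the HF_LEROBOT_HOME / repo convention used by _build_eagle_processor.
#
#   from lerobot.utils.constants import HF_LEROBOT_HOME
#
#   assets_repo = "lerobot/eagle2hg-processor-groot-n1p5"
#   ensure_eagle_cache_ready(
#       vendor_dir=Path("path/to/vendored/eagle_files"),  # placeholder path
#       cache_dir=HF_LEROBOT_HOME / assets_repo,
#       assets_repo=assets_repo,
#   )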
@@ -485,6 +485,7 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
@@ -70,13 +70,14 @@ class SmolVLMWithExpertModel(nn.Module):
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
device: str = "auto",
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
device_map=device,
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
+41
@@ -22,6 +22,8 @@ import numpy as np
import torch
from torch import nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
@@ -198,3 +200,42 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
}
return act_processed_policy
def raise_feature_mismatch_error(
provided_features: set[str],
expected_features: set[str],
) -> None:
"""
Raises a standardized ValueError for feature mismatches between dataset/environment and policy config.
"""
missing = expected_features - provided_features
extra = provided_features - expected_features
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
raise ValueError(
f"Feature mismatch between dataset/environment and policy config.\n"
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
f"Please ensure your dataset and policy use consistent feature names.\n"
f"If your dataset uses different observation keys (e.g., cameras named differently), "
f"use the `--rename_map` argument, for example:\n"
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
f'"observation.images.top": "observation.images.camera2"}}\''
)
def validate_visual_features_consistency(
cfg: PreTrainedConfig,
features: dict[str, PolicyFeature],
) -> None:
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
if not provided_visuals.issubset(expected_visuals):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
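# Illustrative usage sketch: checking that the camera keys provided by a dataset or
# environment are a subset of the visual inputs the policy config expects. `cfg` and
# `features` are placeholders for a loaded PreTrainedConfig and the PolicyFeature mapping
# built from your dataset/env metadata; a mismatch raises the ValueError above, which
# suggests fixing key names via `--rename_map`.
#
#   features = {
#       "observation.images.camera1": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
#       "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(6,)),
#   }
#   validate_visual_features_consistency(cfg, features)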
+154
@@ -0,0 +1,154 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state: concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(8,),  # [eef_pos(3), axis_angle(3), gripper(2)], float32
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
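# Illustrative usage sketch: applying the step to a batched LIBERO observation. The tensor
# shapes below are placeholders consistent with the docstring (B=1, one camera, eef pos/quat
# and a 2-dof gripper); the real keys come from the LIBERO env wrapper.
#
#   step = LiberoProcessorStep()
#   obs = {
#       f"{OBS_IMAGES}.front": torch.zeros(1, 3, 256, 256),
#       "observation.robot_state": {
#           "eef": {"pos": torch.zeros(1, 3), "quat": torch.tensor([[0.0, 0.0, 0.0, 1.0]])},
#           "gripper": {"qpos": torch.zeros(1, 2)},
#       },
#   }
#   out = step.observation(obs)   # out[OBS_STATE] has shape (1, 8)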
+1 -1
@@ -65,7 +65,7 @@ def main(cfg: TrainRLServerPipelineConfig):
# env_cfg=cfg.env,
ds_meta=dataset_meta,
)
-policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
+policy = policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
policy.eval()
eval_policy(env, policy=policy, n_episodes=10)
+1 -1
@@ -153,7 +153,7 @@ class LeKiwi(Robot):
homing_offsets.update(dict.fromkeys(self.base_motors, 0))
full_turn_motor = [
-motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist"])
+motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist_roll"])
]
unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor]
+14 -5
@@ -18,14 +18,24 @@ import base64
import json
import logging
import time
from dataclasses import dataclass, field
import cv2
import draccus
import zmq
from .config_lekiwi import LeKiwiConfig, LeKiwiHostConfig
from .lekiwi import LeKiwi
@dataclass
class LeKiwiServerConfig:
"""Configuration for the LeKiwi host script."""
robot: LeKiwiConfig = field(default_factory=LeKiwiConfig)
host: LeKiwiHostConfig = field(default_factory=LeKiwiHostConfig)
class LeKiwiHost:
def __init__(self, config: LeKiwiHostConfig):
self.zmq_context = zmq.Context()
@@ -47,17 +57,16 @@ class LeKiwiHost:
self.zmq_context.term()
-def main():
+@draccus.wrap()
+def main(cfg: LeKiwiServerConfig):
logging.info("Configuring LeKiwi")
-robot_config = LeKiwiConfig()
-robot = LeKiwi(robot_config)
+robot = LeKiwi(cfg.robot)
logging.info("Connecting LeKiwi")
robot.connect()
logging.info("Starting HostAgent")
-host_config = LeKiwiHostConfig()
-host = LeKiwiHost(host_config)
+host = LeKiwiHost(cfg.host)
last_cmd_time = time.time()
watchdog_active = False
-173
@@ -1,173 +0,0 @@
This tutorial explains how to use [Stretch 3](https://hello-robot.com/stretch-3-product) with LeRobot.
## Setup
Familiarize yourself with Stretch by following its [tutorials](https://docs.hello-robot.com/0.3/getting_started/hello_robot/) (recommended).
To use LeRobot on Stretch, 3 options are available:
- [tethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup)
- [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup)
- ssh directly into Stretch (you will first need to install and configure openssh-server on Stretch using one of the two setups above)
## Install LeRobot
On Stretch's CLI, follow these steps:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Comment out these lines in `~/.profile` (they can mess up the paths used by conda, and `~/.local/bin` should already be in your PATH):
```
# set PATH so it includes user's private bin if it exists
if [ -d "$HOME/.local/bin" ] ; then
PATH="$HOME/.local/bin:$PATH"
fi
```
3. Restart shell or `source ~/.bashrc`
4. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
5. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
6. When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
```
7. Install LeRobot with stretch dependencies:
```bash
cd ~/lerobot && pip install -e ".[stretch]"
```
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
```bash
stretch_system_check.py
```
> **Note:** You may need to free the "robot process" after booting Stretch by running `stretch_free_robot_process.py`. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#turning-off-gamepad-teleoperation).
You should get something like this:
```bash
For use with S T R E T C H (R) from Hello Robot Inc.
---------------------------------------------------------------------
Model = Stretch 3
Tool = DexWrist 3 w/ Gripper
Serial Number = stretch-se3-3054
---- Checking Hardware ----
[Pass] Comms are ready
[Pass] Actuators are ready
[Warn] Sensors not ready (IMU AZ = -10.19 out of range -10.1 to -9.5)
[Pass] Battery voltage is 13.6 V
---- Checking Software ----
[Pass] Ubuntu 22.04 is ready
[Pass] All APT pkgs are setup correctly
[Pass] Firmware is up-to-date
[Pass] Python pkgs are up-to-date
[Pass] ROS2 Humble is ready
```
## Teleoperate, record a dataset and run a policy
**Calibrate (Optional)**
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=calibrate
```
This is equivalent to running `stretch_robot_home.py`
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
**Teleoperate**
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
Now try out teleoperation (see above documentation to learn about the gamepad controls):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=teleoperate
```
This is essentially the same as running `stretch_gamepad_teleop.py`
**Record a dataset**
Once you're familiar with the gamepad controls and after a bit of practice, you can try to record your first dataset with Stretch.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record one episode:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/stretch_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
> **Note:** If you're using ssh to connect to Stretch and run this script, you won't be able to visualize its camera feeds (though they will still be recording). To see the camera streams, use the [tethered](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) or [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup).
**Replay an episode**
Now try to replay this episode (make sure the robot's initial position is the same):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/stretch_test \
--control.episode=0
```
If you need help, please reach out on Discord in the channel `#stretch3-mobile-arm`.
@@ -1,51 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.cameras.realsense import RealSenseCameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
index_or_path="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": RealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": RealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)
@@ -1,180 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.datasets.utils import get_nested_item
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from ..robot import Robot
from .configuration_stretch3 import Stretch3RobotConfig
# {lerobot_keys: stretch.api.keys}
STRETCH_MOTORS = {
"head_pan.pos": "head.head_pan.pos",
"head_tilt.pos": "head.head_tilt.pos",
"lift.pos": "lift.pos",
"arm.pos": "arm.pos",
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
"gripper.pos": "end_of_arm.stretch_gripper.pos",
"base_x.vel": "base.x_vel",
"base_y.vel": "base.y_vel",
"base_theta.vel": "base.theta_vel",
}
class Stretch3Robot(Robot):
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
config_class = Stretch3RobotConfig
name = "stretch3"
def __init__(self, config: Stretch3RobotConfig):
raise NotImplementedError
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.api = StretchAPI()
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False
self.logs = {}
self.teleop = None # TODO remove
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
self.state_keys = None
self.action_keys = None
@property
def observation_features(self) -> dict:
return {
"dtype": "float32",
"shape": (len(STRETCH_MOTORS),),
"names": {"motors": list(STRETCH_MOTORS)},
}
@property
def action_features(self) -> dict:
return self.observation_features
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
def connect(self) -> None:
self.is_connected = self.api.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.is_connected = self.is_connected and cam.is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.calibrate()
def calibrate(self) -> None:
if not self.api.is_homed():
self.api.home()
def _get_state(self) -> dict:
status = self.api.get_status()
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict = {}
# Read Stretch state
before_read_t = time.perf_counter()
state = self._get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
state = np.asarray(list(state.values()))
obs_dict[OBS_STATE] = state
# Capture images from cameras
for cam_key, cam in self.cameras.items():
before_camread_t = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
return obs_dict
def send_action(self, action: np.ndarray) -> np.ndarray:
if not self.is_connected:
raise ConnectionError()
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
if self.action_keys is None:
dummy_action = self.teleop.gamepad_controller.get_state()
self.action_keys = list(dummy_action.keys())
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
before_write_t = time.perf_counter()
self.teleop.do_motion(state=action_dict, robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# TODO(aliberts): return action_sent when motion is limited
return action
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
def disconnect(self) -> None:
self.api.stop()
if self.teleop is not None:
self.teleop.gamepad_controller.stop()
self.teleop.stop()
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False
-8
@@ -40,14 +40,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .lekiwi import LeKiwi
return LeKiwi(config)
elif config.type == "stretch3":
from .stretch3 import Stretch3Robot
return Stretch3Robot(config)
elif config.type == "viperx":
from .viperx import ViperX
return ViperX(config)
elif config.type == "hope_jr_hand":
from .hope_jr import HopeJrHand
-196
@@ -1,196 +0,0 @@
This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot.
## Setup
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
```bash
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
By running the following code, you can start your first **SAFE** teleoperation:
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=5 \
--control.type=teleoperate
```
By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). Setting it to `5` limits the magnitude of the movements for extra safety, but makes teleoperation less smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=teleoperate
```
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/aloha_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=2 \
--control.push_to_hub=true
```
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id, which is given by:
```bash
echo ${HF_USER}/aloha_test
```
If you uploaded with `--control.push_to_hub=false` (i.e. didn't push to the hub), you can still visualize your dataset locally with [Rerun](https://github.com/rerun-io/rerun):
```bash
lerobot-dataset-viz \
--repo-id ${HF_USER}/aloha_test --episode 0
```
## Replay an episode
\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--robot.max_relative_target=null \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/aloha_test \
--control.episode=0
```
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/aloha_test \
--policy.type=act \
--output_dir=outputs/train/act_aloha_test \
--job_name=act_aloha_test \
--policy.device=cuda \
--wandb.enable=true
```
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../src/lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=${HF_USER}/eval_act_aloha_test \
--control.tags='["tutorial"]' \
--control.warmup_time_s=5 \
--control.episode_time_s=30 \
--control.reset_time_s=30 \
--control.num_episodes=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \
--control.num_image_writer_processes=1
```
As you can see, it's almost the same command as the one previously used to record your training dataset. Three things changed:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras to disk allows us to reach a constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
Some files were not shown because too many files have changed in this diff.