Compare commits

..

35 Commits

Author SHA1 Message Date
hf-secutity-analysis[bot] 102810593f fix(security): remediate workflow vulnerability in .github/workflows/full_tests.yml 2026-03-06 09:15:53 +00:00
Steven Palma e489ba24fc feat(dependencies): require Python 3.12+ as minimum version (#3023)
* feat(dependecies): upgrade to python3.12

* fix(test): processor regex message

* fix(test): processor regex message

* fix(dependecies): resolve all tags in python 3.12

* fix(dependecies): add more hints to faster resolve

* chore(dependecies): remove cli tag huggingface-hub dep

* refactor(policy): update eagle for python3.12

* chore(docs): update policy creation for python 3.12

* chore(test): skip failing tests in macos
2026-03-06 10:15:13 +01:00
Steven Palma d324ffe810 fix(ci): test only multi-gpu tests in multi-gpu runner (#3092) 2026-03-05 19:53:40 +01:00
Pepijn 1a24f770d3 Feat/slurm compute rabc script (#3041)
* Add SLURM SARM progress annotation script.

Provide a standalone two-stage compute/aggregate pipeline for RA-BC progress generation so large datasets can be processed in parallel and optionally uploaded to the Hub.

Made-with: Cursor

* fix pr comments

* remove comments
2026-03-05 18:27:58 +01:00
Caroline Pascal 92fba37225 fix(num_frames): fixing redundant frames count in conversion script (#3091) 2026-03-05 15:49:50 +01:00
Steven Palma 3e45120272 fix(ci): log in HF for gated repo in nightly workflows (#3089)
* fix(ci): log in HF for gated repo in nightly workflows

* fix(ci): add env var

* fix(ci): remove 10 min limit for multi-gpu nightly
2026-03-05 13:22:37 +01:00
Steven Palma f0d2b37beb chore(dependencies): bump transformers v5 (#2964)
* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy

* chore(dependencies): bump pi0 family to transformers v5

* chore(dependencies): bump wall x to transformers v5

* chore(dependencies): bump gr00t to transformers v5

* chore(style): fix pre-commit

* fix(policy): xvla forced_bos_token missing

* test(rl): skip ci tests for resnet10

* Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)

* test(policies): enable wall x CI testing

* style(test): pre-commit check

* style(test): pre-commit

* fix wall x for transformer v5 (#3008)

* tv5 fix

* various wall x fixes

* Delete tests/policies/pi0_pi05/print_pi05_output_logits.py

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* sync modeling_florence2.py with chore/bump_transformers_v5

* more

* more fixes

* more

* remove comment

* more

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* chore(dependencies): adjust dependencies versioning after transformers v5 (#3034)

* chore(dependecies): adjust dependecies versioning after transformers v5

* fix(policies): remove deprecated input_embeds

* fix(policies): dict _tied_weights_keys

* chore(depedencies): common qwen-vl-utils

* chore(dependencies): bump transformers to 5.2

* Fix policy testing for tv5 (#3032)

* fix ci logger

* other fix

* fix mypy

* change logits to torch2.10

* skip wallx|

* remove logging

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>

* feat(ci): log into HF to unblock some CI tests (#3007)

* feat(ci): log into HF to unblock some CI tests

* chore(ci): change hf call + secret name

* fix(ci): temp fix for pi0 rtc test

* test(policies): require_cuda for unblocked tests

* test(policies): require_cuda wall_x

* fic(tests): require_cuda outter most for pi0

* fix(test): return instead of yield

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* style(test): fix pre-commit

* chore(deps): upgrade transformers (#3050)

* chore(test): use lerobot model

* fix(policies): change default action tokenizer for wall x

* sample on cpu

* Revert "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit d9b76755f7, reversing
changes made to 89359cb0b6.

* Reapply "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit c9914db78b.

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-03-05 09:25:26 +01:00
Caroline Pascal cbc8bfb2e6 chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label (#3077)
* chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label

* chore(task): renamming the default index label in the tasks DataFrame to task

* Revert "chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label"

This reverts commit f55de3255278f23f18b5d955565f6768d094951d.

* chore(docstrings): updating docstrings to match dataset v3.0 architecture

* chore(format): formatting code
2026-03-04 17:59:03 +01:00
Paul Crook 0d1be72dc8 Fixing metadata indexing when writing new Parquet file (#2941)
* Fixing metadata indexing when writing new Parquet file

Summary:
  - addressing this issue: https://github.com/huggingface/lerobot/issues/2401
  - vibe-coded bugfix by Claude Sonnet 4.5

* Backing out changes to convert_videos_of_camera

* Addressing Ruff pre-commit complaint

Summary:
 - addressing "SIM113 Use `enumerate()` for index variable `ep_idx` in `for` loop"

---------

Co-authored-by: Paul <238953601+pac-robotics@users.noreply.github.com>
2026-03-04 16:53:34 +01:00
Maxime Ellerbach 96b7c212c4 chore(docs): updating deprecated huggingface-cli to hf (#3071)
* chore(docs): updating deprecated huggingface-cli to hf

* small typo in my-org
2026-03-04 15:08:49 +01:00
Caroline Pascal 4303b3c930 chore(root): fixing root semantics in convert_dataset script (#3073)
* fix(root): fixing root semantincs in convert_dataset script

* fix(\): fixing command syntax in dataset conversion script

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-03-04 11:11:21 +01:00
Caroline Pascal 63dca86df8 fix(dataset edit tools): clarifying root argument usage + adding related features (#3049)
* fix(root): adding proper support for the root and new_root arguments

* feat(roots): adding a roots agrument for the merge operation

* chore(clean): cleaning up code

* chore(doctrings): updating doctrings with new features

* fix(repo_id): setting repo_id to None when not needed

* fix(roots/repo_ids): making mypy happy by using repo_ids and roots for merge operation

* fix(path): fixing path related issues

* fix(repo_id): fixing issues related to repo_id

* chore(doctrings): updating docstrings + fix typo

* chore(clean): cleaning code

* fix(split new_repo_id): reverting new_repo_id addition for split operation

* docs(dosctrings): completing docstrings

* fix(repo_ids/roots): improving checks for repo_ids/roots lengths

* fix(repo_ids): making repo_ids optional in MergeConfig but raise if not given

* fix(docstrings): fixing docstrings for split operation

* fix(hints): updating get_output_path hints to accept paths as strings too

* fix(y/N prompts): removing y/N prompts in lerobot_edit_dataset

* fix(merge repo_id): fixing merge operation to use new_repo_id instead of repo_id

* fix(typo): fixing typo in doctrings
2026-03-03 15:40:46 +01:00
Caroline Pascal 8a0cc3d664 fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets (#3068)
* fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets

* fix(segfault risk): removing segfault risk by calling  batch["index"] in the dataloader loop
2026-03-03 11:55:09 +01:00
Bernie Telles 8bb8ed4803 Improve policy_device documentation for async.mdx (#3060) 2026-03-02 15:35:15 +01:00
Steven Palma 095856b06a chore: add AI policy (#3055) 2026-02-28 14:41:28 +01:00
Steven Palma 563f42bdb1 chore(dependencies): Bump lerobot to 0.4.5 (#3051) 2026-02-27 19:29:35 +01:00
Caroline Pascal 8fff0fde7c chore(docstrings): fixing deprecated root argument description in LeRobotDataset class (#3035)
* chore(docstrings): fixing deprecated `root` argument docstrings in LeRobotDataset class

* chore(draccus): updating draccus CLI help

* chore(revert): reverting changes in lerobot_dataset_viz.py

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:22:44 +01:00
Pepijn 04de496547 fix(logging): avoid double-counting samples across processes (#3045) 2026-02-27 17:45:19 +01:00
Khalil Meftah baf9b50365 Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)
* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
2026-02-27 17:44:53 +01:00
Jade Choghari a0fdbf037a feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

* Add torch.compile for smolvla

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Add torch.compile for smolvla

Add model compilation option for improved performance.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* first

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:58:36 +03:00
Khalil Meftah c085531b17 fix: add missing openarm_mini import to CLI scripts (#3028) 2026-02-27 15:46:31 +01:00
Steven Palma c7c6205332 chore(scripts): no spam log when no action (#3042) 2026-02-27 15:26:56 +01:00
Michio Sun 4e54be1334 fix(datasets): skip warning when MultiLeRobotDataset features are identical (#3019)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 17:42:22 +01:00
Damien LaRocque fde9d08281 feat(async_inference) Enable plugins with async inference (#2425)
* feat(async-inference) Try using async inference server with plugins

* Fix import

* Fix import error in Robot Client

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 14:41:32 +01:00
Khalil Meftah 46044fed75 Fix: remove device_map from SmolVLA model loading (#3029)
* Fix SmolVLA meta tensor error by removing device_map

- Remove device_map parameter from VLM model loading
- Change torch_dtype from string to torch.bfloat16
- Add explicit .to(device) calls after initialization

This resolves NotImplementedError when training SmolVLA policy.
Fixes meta tensor copy issue in factory.py:418.

* fix: remove manual device movement logic and fix dtype handling

---------

Co-authored-by: Highsky7 <albert31115@gmail.com>
2026-02-26 13:28:46 +01:00
Khalil Meftah 975dcad918 Feat(teleoperators): add OpenArm Mini teleoperator (#3022)
* add OpenArm Mini config and module init

* add OpenArm Mini teleoperator implementation

* add OpenArm Mini into factory and setup motors

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 18:46:55 +01:00
Cotton Hu d0b58190da fix(policies): support dp train when n_obs_steps=1 (#2430)
Co-authored-by: hukongtao <hukongtao@agibot.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-25 17:36:31 +01:00
Mishig 9a5ab8ffab feat: add visualization badge to card template and update dataset card creation with repo_id (#3005)
* feat: add visualization badge to card template and update dataset card creation with repo_id

* Update src/lerobot/datasets/card_template.md

* Update src/lerobot/datasets/card_template.md

---------

Signed-off-by: Mishig <dmishig@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-25 16:02:40 +01:00
Khalil Meftah 7541d72130 Fix SARM dense_only mode: always load episodes_df for target computation (#3021)
* fix annotation mode check

* fix: SARM dense_only mode always load episodes_df for target computation

---------

Co-authored-by: John Newsom <jackmnewsom@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 13:28:01 +01:00
Jash Shah 0317a15bf1 fix(video): replace assertions with proper exceptions in video frame decoding (#3016)
Replaced assert statements with FrameTimestampError exceptions in
decode_video_frames_torchvision and decode_video_frames_torchcodec.

Assertions are unsuitable for runtime validation because they can be
silently disabled with python -O, and they produce unhelpful
AssertionError tracebacks. The codebase already defines
FrameTimestampError for this exact purpose but it was only used
in one of the three validation sites.

Also removed AssertionError from the except clause in
LeRobotDataset.__init__, which was masking video timestamp errors
by silently triggering a dataset re-download instead of surfacing
the actual problem.
2026-02-25 12:29:22 +01:00
Jash Shah f138e5948a Fix metaworld_config.json not bundled in pip installs and AttributeError crash (#3017)
1. Include metaworld_config.json in package distributions by adding it to
   both MANIFEST.in (for sdist) and pyproject.toml package-data (for wheels).
   Without this, pip-installed lerobot raises FileNotFoundError when
   importing the metaworld environment.

2. Fix crash in sanity_check_dataset_name where the error message accesses
   policy_cfg.type when policy_cfg is None, raising AttributeError instead
   of the intended ValueError.

Fixes #2958
2026-02-25 12:29:10 +01:00
Martin Kiefel 8fef4ddab8 fix(dataset): Fix reindexing bug for videos on splits (#2548)
* fix(dataset): Reindex videos based on frame and not on time

Sometimes during split operations the frame timestamp floating
precision leads to frame ending up in the wrong split.

This changes fixes the issues by directly working with frame indices
instead.

* Fix formatting
2026-02-25 11:57:07 +01:00
Steven Palma 18d9cb5ac4 feat(scripts): Integrate tqdm for training progress visualization (#3010) 2026-02-24 19:10:43 +01:00
Steven Palma 5095ab0845 fix(ci): permissions triton (#3011) 2026-02-24 19:09:34 +01:00
Jash Shah dac1efd13d feat: Enable torch.compile for DiffusionPolicy inference (#2486)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-24 17:29:08 +01:00
107 changed files with 2170 additions and 2172 deletions
+7 -1
View File
@@ -44,7 +44,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -61,6 +61,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -89,5 +90,10 @@ jobs:
- name: Install lerobot with test extras
run: uv sync --extra "test"
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest
run: uv run pytest tests -vv --maxfail=10
+15 -3
View File
@@ -37,7 +37,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest action is built, canceling older runs.
@@ -60,6 +60,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -87,6 +88,11 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
@@ -162,6 +168,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -173,6 +180,12 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -187,7 +200,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Get Docker Hub Token and Delete Image
# zizmor: ignore[template-injection]
env:
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
@@ -219,4 +231,4 @@ jobs:
exit 1
fi
# TODO(Steven): Check dockerimages pull in ubuntu
# TODO(Steven): Check dockerimages pull in ubuntu
+17 -4
View File
@@ -28,7 +28,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
@@ -119,6 +119,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
@@ -130,6 +131,10 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on CPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -146,6 +151,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -157,6 +163,10 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -174,6 +184,7 @@ jobs:
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -185,12 +196,14 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
run: pytest -vv tests/training/
+1 -1
View File
@@ -50,7 +50,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Run pre-commit hooks
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
+2 -2
View File
@@ -22,7 +22,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
jobs:
# This job builds the Python package and publishes it to PyPI
@@ -45,7 +45,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Extract Version
id: extract_info
+11 -2
View File
@@ -29,7 +29,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
# Ensures that only the latest action is built, canceling older runs.
@@ -48,6 +48,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -79,7 +80,10 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -137,6 +141,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -148,6 +153,10 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv
- name: Run end-to-end tests
+2 -2
View File
@@ -13,7 +13,7 @@
# limitations under the License.
default_language_version:
python: python3.10
python: python3.12
exclude: "tests/artifacts/.*\\.safetensors$"
@@ -55,7 +55,7 @@ repos:
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
args: [--py312-plus]
##### Markdown Quality #####
- repo: https://github.com/rbubley/mirrors-prettier
+25
View File
@@ -0,0 +1,25 @@
# AI Usage Policy
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
## Remember the Human Maintainers
Please remember that LeRobot is maintained by a dedicated team of humans.
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
## AI is Welcome Here
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+1 -1
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
## Ways to Contribute
+1
View File
@@ -1,2 +1,3 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
+3 -1
View File
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
# Define Python version argument
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
# Configure environment variables
ENV DEBIAN_FRONTEND=noninteractive \
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+1 -1
View File
@@ -19,7 +19,7 @@
# docker run -it --rm lerobot-user
# Configure the base image
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim
# Configure environment variables
-2
View File
@@ -17,8 +17,6 @@
title: Train RL in Simulation
- local: multi_gpu_training
title: Multi GPU training
- local: hil_collection
title: Human In the Loop Data Collection
- local: peft_training
title: Training with PEFT (e.g., LoRA)
title: "Tutorials"
+1 -1
View File
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+4 -4
View File
@@ -32,7 +32,7 @@ version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"
requires-python = ">= 3.12"
[build-system]
build-backend = # your-build-backend
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any
from typing import Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
@@ -102,7 +102,7 @@ Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Dict, Any
from typing import Any
import torch
+3 -3
View File
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
### Hardware
- EarthRover Mini robot
- Computer with Python 3.10 or newer
- Computer with Python 3.12 or newer
- Internet connection
### Setting Up the Frodobots SDK
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
+2 -2
View File
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
hf auth login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
hf repo create my-org/my-custom-env
# Initialize git and push
git init
-237
View File
@@ -1,237 +0,0 @@
# Human-In-the-Loop Data Collection
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
---
## Why Human-In-the-Loop?
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
- Running the trained policy on the real robot
- Having a human intervene when the robot is about to fail
- Recording the human's recovery and correction as training data
- Fine-tuning the policy on the combined dataset
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
---
## How It Works
During a HIL session, the human operator follows this loop within each episode:
1. **Watch** the policy run autonomously
2. **Pause** when failure is imminent — the robot holds its position
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
4. **Return control to the policy** — the policy resumes autonomous execution
5. Repeat steps 24 as many times as needed during the episode
6. **End the episode** when the task is complete, save and move on to the next rollout
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
```
┌─────────────────────────────────────────────────────────────────────────┐
│ Policy v0 (trained on demos) │
│ ↓ │
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
│ ↓ │
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
│ ↓ │
│ ... (repeat until satisfactory performance) │
└─────────────────────────────────────────────────────────────────────────┘
```
---
## Hardware Requirements
### Teleoperator Requirements
The HIL data collection scripts require **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators:**
- `so101_leader` - SO-101 Leader Arm
- `openarms_mini` - OpenArms Mini (via third-party plugin)
---
## Scripts
Two scripts are provided depending on your policy's inference speed:
| Script | Use Case | Models |
| ---------------------------- | ------------------------------------------ | --------------------- |
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
---
## Step-by-Step Guide
### Step 1: Pre-train a Base Policy
First, train a policy on your demonstration dataset:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/demo-dataset \
--policy.type=pi0 \
--output_dir=outputs/pretrain \
--batch_size=32 \
--steps=50000
```
### Step 2: Collect HIL Data
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Pick up the cube and place it in the bowl" \
--dataset.num_episodes=50
```
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
For models with high inference latency, use the RTC script for smooth execution:
```bash
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--teleop.type=so100_leader \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20 \
--interpolation=true
```
**Controls (Conceptual):**
The interaction model is:
- **Pause input**: pause autonomous policy execution
- **Takeover input**: transfer control to the human operator and record intervention data
- **Return-to-policy input**: hand control back to the policy and continue the same episode
- **Episode control inputs**: save/re-record/stop/reset as needed
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
**The HIL Protocol:**
1. Watch the policy run autonomously (teleop is idle/free)
2. When you see imminent failure, trigger the **pause input**
- Policy stops
- Teleoperator moves to match robot position (torque enabled)
- No frames recorded during pause
3. Trigger the **takeover input** to take control
- Teleoperator torque disabled, free to move
- **Recovery**: Teleoperate the robot back to a good state
- **Correction**: Correct the behavior
- All movements are recorded
4. Trigger the **return-to-policy input**
- Policy resumes autonomous execution from the current state
- You can intervene again at any time (repeat steps 24)
5. End and save the episode when the task is complete (or episode time limit is reached)
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
7. Start the next episode
**Foot Pedal Setup (Linux):**
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
```bash
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
```
### Step 3: Fine-tune the Policy
Fine-tune on the combined demonstration + HIL data:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/hil-dataset \
--policy.type=pi0 \
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
--output_dir=outputs/hil_finetune \
--steps=20000
```
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
---
## Tips for Effective HIL Collection
### When to Intervene
Intervene when you see:
- Robot about to make an irreversible mistake
- Robot hesitating or showing uncertain behavior
- Robot deviating from the expected trajectory
### Recovery: Teleoperating Back to a Good State
During recovery, teleoperate the robot back to a state where:
- The robot is in a familiar, in-distribution configuration
- The current subtask can still be completed
- The recovery trajectory itself is informative training data
### Quality of Corrections
During correction:
- Provide **confident, clean** trajectories
- Complete the current subtask fully
- Don't overcorrect or add unnecessary movements
---
## Related Work
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
```bibtex
@article{ross2011dagger,
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
year={2011}
}
@article{kelly2019hgdagger,
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
journal={arXiv preprint arXiv:1810.02890},
year={2019}
}
@article{hu2025rac,
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
journal={arXiv preprint arXiv:2509.07953},
year={2025}
}
@article{pi2025recap,
title={π0.6: a VLA That Learns From Experience},
author={Physical Intelligence},
year={2025}
}
```
+4 -4
View File
@@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
You can also push your local dataset to the Hub manually, running:
```bash
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
```
#### Record function
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
Once training is done, upload the latest checkpoint with:
```bash
huggingface-cli upload ${HF_USER}/act_so101_test \
hf upload ${HF_USER}/act_so101_test \
outputs/train/act_so101_test/checkpoints/last/pretrained_model
```
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
hf upload ${HF_USER}/act_so101_test${CKPT} \
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
```
+3 -3
View File
@@ -1,6 +1,6 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
@@ -11,10 +11,10 @@ bash Miniforge3-$(uname)-$(uname -m).sh
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
Create a virtual environment with Python 3.12, using conda:
```bash
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
+2 -2
View File
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
+10 -10
View File
@@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
You have two options for the FAST tokenizer:
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
@@ -114,15 +114,15 @@ lerobot-train \
### Key Training Parameters
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
## Inference
+2 -2
View File
@@ -123,7 +123,7 @@ SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
@@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
+1 -1
View File
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
+490
View File
@@ -0,0 +1,490 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SLURM-distributed SARM RA-BC annotation pipeline.
Computes SARM progress values for all frames in a dataset, distributed across
SLURM workers, then merges the shards into a single sarm_progress.parquet.
Two subcommands, each a separate SLURM submission:
compute N workers, each computes progress for a subset of episodes
aggregate 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
Usage:
python slurm_compute_rabc.py compute \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--stride 10 --device cpu --workers 50 --partition cpu
python slurm_compute_rabc.py aggregate \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--partition cpu --push-to-hub
"""
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
class ComputeProgressShards(PipelineStep):
"""Each worker computes SARM progress for its assigned episodes."""
def __init__(
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
):
super().__init__()
if stride < 1:
raise ValueError(f"stride must be >= 1, got {stride}")
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.stride = stride
self.head_mode = head_mode
self.device = device
self.shard_dir = shard_dir
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
)
from lerobot.utils.utils import init_logging
init_logging()
dataset, reward_model, preprocess = load_sarm_resources(
self.repo_id,
self.reward_model_path,
self.device,
)
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
frame_gap = reward_model.config.frame_gap
center_idx = reward_model.config.n_obs_steps // 2
dual_mode = reward_model.config.uses_dual_heads
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
compute_dense = self.head_mode in ("dense", "both") and dual_mode
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
if not my_episodes:
logging.info(f"Rank {rank}: no episodes assigned")
return
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
all_rows = []
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
ep = dataset.meta.episodes[ep_idx]
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
task = dataset[ep_start].get("task", "perform the task")
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
if self.stride > 1:
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
if (ep_end - 1) not in compute_indices:
compute_indices.append(ep_end - 1)
compute_indices = sorted(set(compute_indices))
else:
compute_indices = all_ep_indices
frame_results = {}
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
try:
sample = dataset[qi]
batch = {
image_key: sample[image_key],
"task": task,
"index": qi,
"episode_index": ep_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
vf = processed["video_features"].to(self.device)
tf = processed["text_features"].to(self.device)
sf = processed.get("state_features")
if sf is not None:
sf = sf.to(self.device)
lengths = processed.get("lengths")
sparse_val = dense_val = np.nan
if compute_sparse:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="sparse",
)
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
if compute_dense:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="dense",
)
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
frame_results[qi] = (sparse_val, dense_val)
except Exception as e:
logging.warning(f"Failed frame {qi}: {e}")
if not frame_results:
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
continue
# Interpolate to all frames in this episode
computed_idx = np.array(sorted(frame_results.keys()))
all_frame_arr = np.arange(ep_start, ep_end)
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
if self.stride > 1 and len(computed_idx) > 1:
if compute_sparse:
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
if compute_dense:
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
output_frames = all_frame_arr
else:
# Use only successfully computed frames to avoid indexing mismatch on failures
output_frames = computed_idx
for i, fi in enumerate(output_frames):
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
if compute_sparse:
row["progress_sparse"] = float(sparse_vals[i])
if compute_dense:
row["progress_dense"] = float(dense_vals[i])
all_rows.append(row)
if all_rows:
import pandas as pd
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
shard_dir = Path(self.shard_dir)
shard_dir.mkdir(parents=True, exist_ok=True)
out = shard_dir / f"shard_{rank:05d}.parquet"
pq.write_table(table, out)
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
class AggregateProgress(PipelineStep):
"""Merge all shard parquets into final sarm_progress.parquet."""
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
super().__init__()
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.shard_dir = shard_dir
self.push_to_hub = push_to_hub
def run(self, data=None, rank: int = 0, world_size: int = 1):
import datetime
import logging
import os
from pathlib import Path
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
init_logging()
if rank != 0:
return
shard_dir = Path(self.shard_dir)
shards = sorted(shard_dir.glob("shard_*.parquet"))
if not shards:
raise FileNotFoundError(f"No shards found in {shard_dir}")
# Log shard modification time range to help detect stale files
mtimes = [os.path.getmtime(s) for s in shards]
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
df = df.sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
out_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out_path)
logging.info(f"Saved {len(df)} rows to {out_path}")
for col in ["progress_sparse", "progress_dense"]:
if col in df.columns:
v = df[col].dropna()
logging.info(
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
)
if self.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = "sarm_progress.parquet"
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(out_path),
path_in_repo=hub_path,
repo_id=self.repo_id,
repo_type="dataset",
)
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
def make_compute_executor(
repo_id,
reward_model_path,
stride,
head_mode,
device,
shard_dir,
logs_dir,
job_name,
slurm,
workers,
partition,
cpus_per_task,
mem_per_cpu,
):
kwargs = {
"pipeline": [
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": workers,
"workers": workers,
"time": "24:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": workers, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def make_aggregate_executor(
repo_id,
reward_model_path,
shard_dir,
logs_dir,
job_name,
slurm,
partition,
cpus_per_task,
mem_per_cpu,
push_to_hub,
):
kwargs = {
"pipeline": [
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": 1,
"workers": 1,
"time": "02:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": 1, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def _add_shared_args(p):
p.add_argument(
"--repo-id",
type=str,
required=True,
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
)
p.add_argument(
"--shard-dir",
type=Path,
default=Path("rabc_shards"),
help="Directory to read/write per-rank parquet shards.",
)
p.add_argument(
"--logs-dir",
type=Path,
default=Path("logs"),
help="Directory for datatrove logs.",
)
p.add_argument(
"--job-name",
type=str,
default=None,
help="SLURM job name (defaults to rabc_<subcommand>).",
)
p.add_argument(
"--slurm",
type=int,
default=1,
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
)
p.add_argument(
"--partition",
type=str,
default=None,
help="SLURM partition to submit to.",
)
p.add_argument(
"--cpus-per-task",
type=int,
default=4,
help="Number of CPUs per SLURM task.",
)
p.add_argument(
"--mem-per-cpu",
type=str,
default="4G",
help="Memory per CPU, e.g. '4G' or '1950M'.",
)
def main():
parser = argparse.ArgumentParser(
description="SLURM-distributed SARM RA-BC annotation pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
sub = parser.add_subparsers(dest="command", required=True)
# compute subcommand
cp = sub.add_parser(
"compute",
help="Distribute progress computation across SLURM workers.",
)
_add_shared_args(cp)
cp.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model.",
)
cp.add_argument(
"--stride",
type=int,
default=1,
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
)
cp.add_argument(
"--head-mode",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="Which reward head(s) to compute.",
)
cp.add_argument(
"--device",
type=str,
default="cpu",
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
)
cp.add_argument(
"--workers",
type=int,
default=50,
help="Number of parallel SLURM tasks (one shard per worker).",
)
# aggregate subcommand
ap = sub.add_parser(
"aggregate",
help="Merge per-rank shards into a single sarm_progress.parquet.",
)
_add_shared_args(ap)
ap.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
)
ap.add_argument(
"--push-to-hub",
action="store_true",
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
)
args = parser.parse_args()
job_name = args.job_name or f"rabc_{args.command}"
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
kwargs["job_name"] = job_name
command = kwargs.pop("command")
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
executor.run()
if __name__ == "__main__":
main()
-351
View File
@@ -1,351 +0,0 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
The workflow:
1. Policy runs autonomously
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_dataset \
--dataset.single_task="Pick up the cube"
"""
import logging
import time
from dataclasses import dataclass
from pprint import pformat
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
@safe_stop_image_writer
def rollout_loop(
robot: Robot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILConfig,
):
"""Rollout loop with standard synchronous inference."""
fps = cfg.dataset.fps
device = get_safe_torch_device(cfg.device)
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
robot_action: dict[str, Any] = {}
action_keys = sorted(robot.action_features.keys())
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with optional interpolation
if interpolator.needs_new_action():
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=cfg.dataset.single_task,
robot_type=robot.robot_type,
)
robot_action = make_robot_action(action_values, dataset.features)
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
interpolator.add(action_tensor)
interp_action = interpolator.get()
if interp_action is not None:
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
"""Main HIL data collection function."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_collection")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and robot.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect()
listener, events = init_keyboard_listener()
print_controls(rtc=False)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_collect()
if __name__ == "__main__":
main()
-513
View File
@@ -1,513 +0,0 @@
#!/usr/bin/env python
"""
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
asynchronously in a background thread for smooth robot control.
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
The workflow:
1. Policy runs autonomously with RTC
2. Press SPACE to pause - robot holds position
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (robot holds position, no recording)
c - Take control (start correction, recording resumes)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage:
python examples/rac/hil_data_collection_rtc.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
--dataset.repo_id=my_user/hil_rtc_dataset \
--dataset.single_task="Pick up the cube" \
--rtc.execution_horizon=20
"""
import logging
import math
import time
from dataclasses import dataclass, field
from pprint import pformat
from threading import Event, Lock, Thread
from typing import Any
import torch
from hil_utils import (
HILDatasetConfig,
init_keyboard_listener,
make_identity_processors,
print_controls,
reset_loop,
teleop_disable_torque,
teleop_smooth_move_to,
)
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
logger = logging.getLogger(__name__)
@dataclass
class HILRTCConfig:
robot: RobotConfig
teleop: TeleoperatorConfig
dataset: HILDatasetConfig
policy: PreTrainedConfig | None = None
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
display_data: bool = True
play_sounds: bool = True
resume: bool = False
device: str = "cuda"
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
def __post_init__(self):
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("policy.path is required")
self.rtc.enabled = True
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
class ThreadSafeRobot:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: Robot):
self._robot = robot
self._lock = Lock()
def get_observation(self) -> dict[str, Any]:
with self._lock:
return self._robot.get_observation()
def send_action(self, action: dict) -> None:
with self._lock:
self._robot.send_action(action)
@property
def observation_features(self) -> dict:
return self._robot.observation_features
@property
def action_features(self) -> dict:
return self._robot.action_features
@property
def name(self) -> str:
return self._robot.name
@property
def robot_type(self) -> str:
return self._robot.robot_type
@property
def cameras(self):
return getattr(self._robot, "cameras", {})
def rtc_inference_thread(
policy: PreTrainedPolicy,
obs_holder: dict,
hw_features: dict,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
queue_holder: dict,
shutdown_event: Event,
policy_active: Event,
cfg: HILRTCConfig,
):
"""Background thread for RTC action chunk generation."""
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps
threshold = 30
while not shutdown_event.is_set():
if not policy_active.is_set():
time.sleep(0.01)
continue
queue = queue_holder.get("queue")
obs = obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(0.01)
continue
if queue.qsize() <= threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
for name in obs_batch:
obs_batch[name] = torch.from_numpy(obs_batch[name])
if "image" in name:
obs_batch[name] = obs_batch[name].float() / 255
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
obs_batch["task"] = [cfg.dataset.single_task]
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
preprocessed = preprocessor(obs_batch)
actions = policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
except Exception as e:
logger.error(f"[RTC] Error: {e}")
time.sleep(0.5)
else:
time.sleep(0.01)
@safe_stop_image_writer
def rollout_loop(
robot: ThreadSafeRobot,
teleop: Teleoperator,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset: LeRobotDataset,
events: dict,
cfg: HILRTCConfig,
queue_holder: dict,
obs_holder: dict,
policy_active: Event,
hw_features: dict,
):
"""Rollout loop with RTC for asynchronous inference."""
fps = cfg.dataset.fps
policy.reset()
preprocessor.reset()
postprocessor.reset()
frame_buffer = []
teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(fps)
robot_action: dict[str, Any] = {}
timestamp = 0
start_t = time.perf_counter()
while timestamp < cfg.dataset.episode_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
break
# Transition to paused state
if events["policy_paused"] and not was_paused:
policy_active.clear()
obs = robot.get_observation()
robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
events["start_next_episode"] = False
waiting_for_takeover = True
was_paused = True
interpolator.reset()
# Takeover
if waiting_for_takeover and events["start_next_episode"]:
teleop_disable_torque(teleop)
events["start_next_episode"] = False
events["correction_active"] = True
waiting_for_takeover = False
obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
obs_holder["obs"] = obs_filtered
if events["correction_active"]:
robot_action = teleop.get_action()
robot.send_action(robot_action)
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
elif waiting_for_takeover or events["policy_paused"]:
if last_action:
robot.send_action(last_action)
else:
# Policy execution with RTC
if not policy_active.is_set():
policy_active.set()
queue = queue_holder["queue"]
if interpolator.needs_new_action():
new_action = queue.get() if queue else None
if new_action is not None:
interpolator.add(new_action.cpu())
action_tensor = interpolator.get()
if action_tensor is not None:
robot_action = {
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
}
robot.send_action(robot_action)
last_action = robot_action
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
if cfg.display_data and robot_action:
log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start
if (sleep_time := control_interval - dt) > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t
policy_active.clear()
teleop_disable_torque(teleop)
for frame in frame_buffer:
dataset.add_frame(frame)
@parser.wrap()
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
"""Main HIL data collection function with RTC."""
init_logging()
logger.info(pformat(cfg.__dict__))
if cfg.display_data:
init_rerun(session_name="hil_rtc_collection")
robot_raw = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, obs_proc = make_identity_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_proc,
initial_features=create_initial_features(action=robot_raw.action_features),
use_videos=cfg.dataset.video,
),
aggregate_pipeline_dataset_features(
pipeline=obs_proc,
initial_features=create_initial_features(observation=robot_raw.observation_features),
use_videos=cfg.dataset.video,
),
)
dataset = None
listener = None
shutdown_event = Event()
policy_active = Event()
rtc_thread = None
try:
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
)
else:
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot_raw.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load policy with RTC
policy_class = get_policy_class(cfg.policy.type)
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if hasattr(policy_config, "compile_model"):
policy_config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
policy.config.rtc_config = cfg.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot_raw.connect()
robot = ThreadSafeRobot(robot_raw)
teleop.connect()
listener, events = init_keyboard_listener()
queue_holder = {"queue": ActionQueue(cfg.rtc)}
obs_holder = {"obs": None, "robot_type": robot.robot_type}
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
rtc_thread = Thread(
target=rtc_inference_thread,
args=(
policy,
obs_holder,
hw_features,
preprocessor,
postprocessor,
queue_holder,
shutdown_event,
policy_active,
cfg,
),
daemon=True,
)
rtc_thread.start()
print_controls(rtc=True)
print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}")
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
with VideoEncodingManager(dataset):
recorded = 0
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
queue_holder["queue"] = ActionQueue(cfg.rtc)
rollout_loop(
robot=robot,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
events=events,
cfg=cfg,
queue_holder=queue_holder,
obs_holder=obs_holder,
policy_active=policy_active,
hw_features=hw_features,
)
if events["rerecord_episode"]:
log_say("Re-recording", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop(robot, teleop, events, cfg.dataset.fps)
finally:
log_say("Stop recording", cfg.play_sounds, blocking=True)
shutdown_event.set()
policy_active.clear()
if rtc_thread and rtc_thread.is_alive():
rtc_thread.join(timeout=2.0)
if dataset:
dataset.finalize()
if robot_raw.is_connected:
robot_raw.disconnect()
if teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
def main():
from lerobot.utils.import_utils import register_third_party_plugins
register_third_party_plugins()
hil_rtc_collect()
if __name__ == "__main__":
main()
-206
View File
@@ -1,206 +0,0 @@
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if teleop_has_motor_control(teleop):
teleop.bus.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = int(duration_s * fps)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
print("\n[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
print("\n[HIL] ▶ Taking control...")
events["start_next_episode"] = True
elif key == keyboard.Key.right:
print("[HIL] → End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("[HIL] ← Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
print("\n" + "=" * 60)
print(" [HIL] RESET")
print("=" * 60)
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
print(" Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
print(" Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
print("\n" + "=" * 60)
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
print("=" * 60)
print()
print(" Controls:")
print(" SPACE - Pause policy")
print(" c - Take control")
print(" → - End episode")
print(" ESC - Stop and push to hub")
print("=" * 60 + "\n")
+8 -9
View File
@@ -83,7 +83,9 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
@@ -149,7 +151,6 @@ class RTCDemoConfig(HubMixin):
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
@@ -350,22 +351,20 @@ def actor_control(
logger.info("[ACTOR] Starting actor thread")
action_count = 0
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
# Try to get an action from the queue with timeout
action = action_queue.get()
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
+49 -116
View File
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.4"
version = "0.4.5"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
requires-python = ">=3.12"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,7 +50,8 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -61,26 +62,28 @@ dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"huggingface-hub>=1.0.0,<2.0.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"numpy>=2.0.0,<2.3.0", # TODO: upper bound imposed by opencv-python-headless
"setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"packaging>=24.2,<26.0",
"torch>=2.2.1,<2.11.0",
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pynput>=1.7.8,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"draccus==0.10.0", # TODO: Remove ==
"draccus==0.10.0", # TODO: Relax version constraint
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
@@ -95,10 +98,14 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -116,7 +123,7 @@ unitree_g1 = [
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"lerobot[matplotlib-dep]",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
@@ -125,21 +132,21 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
# Policies
wallx = [
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
"lerobot[transformers-dep]",
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"lerobot[peft]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -148,13 +155,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -162,13 +169,18 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
all = [
# Resolver hint: scipy is pulled in transitively via lerobot[scipy-dep] through
# multiple extras below (aloha, metaworld, pi, wallx, phone). Listing it explicitly
# helps pip's resolver converge by constraining scipy early, before it encounters
# the loose scipy requirements from transitive deps like dm-control and metaworld.
"scipy>=1.14.0,<2.0.0",
"lerobot[dynamixel]",
"lerobot[gamepad]",
"lerobot[hopejr]",
@@ -176,8 +188,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -189,7 +201,7 @@ all = [
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[phone]",
"lerobot[libero]",
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[peft]",
@@ -214,11 +226,14 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
[tool.ruff]
target-version = "py310"
target-version = "py312"
line-length = 110
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
@@ -310,7 +325,7 @@ default.extend-ignore-identifiers-re = [
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
[tool.mypy]
python_version = "3.10"
python_version = "3.12"
ignore_missing_imports = true
follow_imports = "skip"
# warn_return_any = true
@@ -394,85 +409,3 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "peft" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]
+7 -10
View File
@@ -49,23 +49,18 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so_follower,
koch_follower,
from lerobot.robots import (
RobotConfig, # noqa: F401
make_robot_from_config,
omx_follower,
so_follower,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
@@ -485,8 +480,9 @@ class RobotClient:
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
# TODO: Assert if checking robot support is still needed with the plugin system
# if cfg.robot.type not in SUPPORTED_ROBOTS:
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):
if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client
+1 -1
View File
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+3 -1
View File
@@ -289,7 +289,9 @@ def aggregate_datasets(
logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
dst_meta.tasks = pd.DataFrame(
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
+7
View File
@@ -7,6 +7,13 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+41 -33
View File
@@ -89,8 +89,8 @@ def delete_episodes(
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -152,7 +152,7 @@ def split_dataset(
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Base directory for output datasets. If None, uses default location.
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
Examples:
Split by specific episodes
@@ -243,8 +243,8 @@ def merge_datasets(
Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +288,8 @@ def modify_features(
dataset: The source LeRobotDataset.
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features modified.
@@ -390,8 +390,8 @@ def add_features(
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with all features added.
@@ -427,8 +427,8 @@ def remove_feature(
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features removed.
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[float, float]],
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified time ranges and re-encodes them with
This function decodes frames from specified frame ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
time_ranges = sorted(episodes_to_keep)
frame_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Check if frame is in any of our desired frame ranges.
# Skip ranges that have already passed.
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(time_ranges):
if range_idx >= len(frame_ranges):
break
# Check if frame is in current range.
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of time ranges to keep, in sorted order.
# Build list of frame ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[float, float]] = []
episodes_to_keep_ranges: list[tuple[int, int]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -1470,7 +1475,9 @@ def modify_tasks(
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
new_task_df = pd.DataFrame(
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
@@ -1524,7 +1531,7 @@ def modify_tasks(
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -1543,8 +1550,8 @@ def convert_image_to_video_dataset(
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
@@ -1595,6 +1602,7 @@ def convert_image_to_video_dataset(
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
+14 -13
View File
@@ -314,7 +314,7 @@ class LeRobotDatasetMetadata:
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
if extra_keys:
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
+3 -4
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any, Generic, TypeVar
from typing import Any
import datasets
import numpy as np
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
T = TypeVar("T")
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
@@ -341,6 +339,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks
@@ -1233,7 +1232,7 @@ class LookAheadError(Exception):
pass
class Backtrackable(Generic[T]):
class Backtrackable[T]:
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory
--root=/path/to/local/dataset/directory \
--push-to-hub=false
N.B. Path semantics (v2): --root is the exact dataset folder containing
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
```
"""
@@ -105,7 +108,7 @@ episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
NEW
meta/episodes/chunk-000/episodes_000.parquet
meta/episodes/chunk-000/file_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
@@ -113,15 +116,16 @@ tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}
NEW
meta/tasks/chunk-000/file_000.parquet
meta/tasks.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
NEW
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
meta/episodes/chunk-000/file_000.parquet
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
-------------------------
UPDATE
meta/info.json
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
write_tasks(df_tasks, new_root)
@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
image_keys = get_image_keys(root)
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
@@ -211,9 +214,23 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
logging.info(f"Converting data files from {len(ep_paths)} episodes")
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
# Check if we need to start a new file BEFORE creating metadata
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
# Write the accumulated data files
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Move to next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Reset for the next file
size_in_mb = 0
paths_to_cat = []
# Now create metadata with correct chunk/file indices
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
@@ -224,20 +241,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Reset for the next file
size_in_mb = ep_size_in_mb
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
paths_to_cat.append(ep_path)
# Write remaining data if any
if paths_to_cat:
@@ -469,7 +473,7 @@ def convert_dataset(
# Set root based on whether local dataset path is provided
use_local_dataset = False
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
if root.exists():
validate_local_dataset_version(root)
use_local_dataset = True
@@ -553,7 +557,7 @@ if __name__ == "__main__":
"--root",
type=str,
default=None,
help="Local directory to use for downloading/writing the dataset.",
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
)
parser.add_argument(
"--push-to-hub",
+26 -20
View File
@@ -227,16 +227,17 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -248,7 +249,11 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
return closest_frames
@@ -353,15 +358,16 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+4 -4
View File
@@ -29,7 +29,7 @@ from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
from typing import Protocol, TypeAlias
from typing import Protocol
import serial
from deepdiff import DeepDiff
@@ -38,8 +38,8 @@ from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
type NameOrID = str | int
type Value = int | float
logger = logging.getLogger(__name__)
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
MotorsBus = SerialMotorsBus
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
backbone. If None, no resizing is done and the original image resolution is used.
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
if self.resize_shape is None and self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for `{key}`."
)
# Check that all input images have the same shape.
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
# Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
+1 -2
View File
@@ -18,10 +18,9 @@ from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, Unpack
import torch
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -4,17 +4,16 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
ImagesKwargs,
group_images_by_shape,
reorder_images,
)
@@ -77,7 +76,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _resize_for_patching(
self,
image: "torch.Tensor",
image: torch.Tensor,
target_resolution: tuple,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
input_data_format: ChannelDimension,
) -> "torch.Tensor":
) -> torch.Tensor:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _get_image_patches(
self,
image: "torch.Tensor",
image: torch.Tensor,
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
) -> list[torch.Tensor]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _preprocess(
self,
images: list["torch.Tensor"],
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
interpolation: F.InterpolationMode | None,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
+68 -54
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,13 +32,21 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
@@ -191,7 +199,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -202,7 +210,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -221,14 +229,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.language_model, gemma_expert.model]
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -254,10 +262,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -265,7 +273,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -277,15 +285,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -358,7 +366,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -366,7 +374,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -377,13 +385,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
torch_dtype="float32",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -398,10 +406,11 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -413,8 +422,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -424,15 +433,23 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -446,7 +463,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -470,7 +487,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.language_model, self.gemma_expert.model]
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -510,7 +527,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -576,29 +593,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
@@ -760,7 +767,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -834,7 +841,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -908,6 +915,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -997,14 +1005,12 @@ class PI0Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -1012,7 +1018,7 @@ class PI0Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1025,7 +1031,7 @@ class PI0Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1070,7 +1076,7 @@ class PI0Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -1120,6 +1126,14 @@ class PI0Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
+71 -60
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,14 +32,20 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -92,10 +98,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
return dist.sample((bsize,)).to(device)
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
@@ -189,7 +196,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -200,7 +207,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -219,14 +226,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.language_model, gemma_expert.model]
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -252,10 +259,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.language_model.layers[layer_idx].self_attn,
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -263,7 +270,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -275,15 +282,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -356,7 +363,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -364,7 +371,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -375,13 +382,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
torch_dtype="float32",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -396,10 +403,11 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -411,8 +419,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -422,15 +430,23 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -444,7 +460,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -468,7 +484,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.language_model, self.gemma_expert.model]
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -508,7 +524,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -573,29 +589,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
@@ -737,7 +743,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -808,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -880,6 +886,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -969,14 +976,12 @@ class PI05Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -984,7 +989,7 @@ class PI05Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -997,7 +1002,7 @@ class PI05Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1009,8 +1014,6 @@ class PI05Policy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -1044,7 +1047,7 @@ class PI05Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -1098,6 +1101,14 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
@@ -23,7 +23,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -68,9 +67,6 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "physical-intelligence/fast"
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128
@@ -19,13 +19,12 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _scipy_available, _transformers_available
@@ -38,11 +37,16 @@ else:
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
else:
CONFIG_MAPPING = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -121,7 +125,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(-1.0, 1.0)
resized_images = resized_images.clamp(0.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -132,7 +136,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else -1.0
constant_value = 0 if images.dtype == torch.uint8 else 0.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -206,16 +210,22 @@ class PI0FastPaliGemma(nn.Module):
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.torch_dtype = "float32"
vlm_config_hf.vision_config.dtype = "float32"
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
self.paligemma.model.language_model = PiGemmaModel(text_config)
self.to_bfloat16_for_selected_params(precision)
@@ -228,10 +238,11 @@ class PI0FastPaliGemma(nn.Module):
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"vision_tower",
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -242,10 +253,18 @@ class PI0FastPaliGemma(nn.Module):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
return self.paligemma.model.get_image_features(image)
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
@@ -259,7 +278,7 @@ class PI0FastPaliGemma(nn.Module):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.language_model.forward(
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -306,24 +325,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
@@ -332,8 +341,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
# Call the proper gradient_checkpointing_disable() method
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
def _apply_checkpoint(self, func, *args, **kwargs):
@@ -523,7 +532,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -616,7 +625,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -714,7 +723,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Ensure correct precision (bfloat16/float32)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -897,14 +906,12 @@ class PI0FastPolicy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Now manually load and remap the state dict
# Load state dict (expects keys with "model." prefix)
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -912,7 +919,7 @@ class PI0FastPolicy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -925,8 +932,9 @@ class PI0FastPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -936,8 +944,6 @@ class PI0FastPolicy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -971,7 +977,7 @@ class PI0FastPolicy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not remap state dict keys: {e}")
print(f"Warning: Could not load state dict: {e}")
return model
@@ -23,7 +23,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
@@ -69,9 +68,6 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
+363
View File
@@ -0,0 +1,363 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch import nn
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
)
else:
GemmaAttention = None
GemmaConfig = None
GemmaForCausalLM = None
GemmaMLP = None
GemmaModel = None
PaliGemmaModel = None
PaliGemmaForConditionalGeneration = None
DynamicCache = None
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
def _gated_residual(
x: torch.Tensor | None,
y: torch.Tensor | None,
gate: torch.Tensor | None,
) -> torch.Tensor | None:
"""Gated residual: x + y when gate is None, else x + y * gate."""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def layernorm_forward(
layernorm: nn.Module,
x: torch.Tensor,
cond: torch.Tensor | None = None,
):
"""
call layernorm and return hidden states and gate
if cond is not None, use conditional norm
otherwise, use normal gemma norm
"""
if cond is not None:
return layernorm(x, cond=cond)
else:
return layernorm(x)
class PiGemmaRMSNorm(nn.Module):
"""
Adaptive RMSNorm for PI Gemma (AdaRMS).
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
"""
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
if cond_dim is not None:
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(
self,
x: torch.Tensor,
cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
dtype = x.dtype
normed = self._norm(x)
if cond is None or self.dense is None:
normed = normed * (1.0 + self.weight.float())
return normed.type_as(x), None
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()
return normed.to(dtype), gate.to(dtype)
def extra_repr(self) -> str:
if self.dense is not None:
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
return f"dim={self.dim}, eps={self.eps}"
def _get_pi_gemma_decoder_layer_base():
"""base for PiGemmaDecoderLayer"""
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = (
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
)
self.input_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
self.post_attention_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
hidden_states, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
return hidden_states
return _PiGemmaDecoderLayerBase
class PiGemmaModel(GemmaModel): # type: ignore[misc]
"""
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# if not getattr(config, "use_adarms", False):
# return
cond_dim = getattr(config, "adarms_cond_dim", None)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
import logging
logging.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
"""
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PiGemmaModel(config)
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
def __init__(self, config):
super().__init__(config)
self.language_model = PiGemmaModel(config.text_config)
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
def __init__(self, config):
super().__init__(config)
self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
__all__ = [
"PiGemmaModel",
"PiGemmaForCausalLM",
"PiGemmaRMSNorm",
"_gated_residual",
"layernorm_forward",
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
+1 -2
View File
@@ -19,7 +19,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypedDict, TypeVar
from typing import TypedDict, TypeVar, Unpack
import packaging
import safetensors
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@@ -1,117 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
self._buffer = [action]
self._prev = action
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
+1 -3
View File
@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
def __post_init__(self):
super().__post_init__()
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
from typing import TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
@@ -593,6 +592,12 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map=device,
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Tokenizer settings
action_tokenizer_path: str | None = "physical-intelligence/fast"
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"
+12 -2
View File
@@ -261,10 +261,15 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and optional LoRA fine-tuning support.
"""
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
def init_weights(self):
if getattr(self.model, "language_model", None) is not None:
return
super().init_weights()
@classmethod
def from_pretrained(
cls,
@@ -312,6 +317,11 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
processor.action_processor = action_tokenizer
else:
action_tokenizer = None
# add pad_token_id to config
config.pad_token_id = processor.tokenizer.pad_token_id
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
# Initialize model with configuration and processor
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
@@ -331,7 +341,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
use_auth_token=kwargs.get("use_auth_token"),
token=kwargs.get("token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -21,6 +21,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@@ -38,6 +39,7 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(PretrainedConfig):
@@ -11,7 +11,6 @@ from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
@@ -31,6 +30,15 @@ from transformers.utils import (
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
# always return False (which is the correct behavior when no sliding window cache is in use).
class _SlidingWindowCachePlaceholder:
pass
SlidingWindowCache = _SlidingWindowCachePlaceholder
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
@@ -594,19 +602,40 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
"""
compute default rope parameters for Qwen2_5_VL
"""
base = config.text_config.rope_parameters["rope_theta"]
dim = config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, 1.0
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
self.rope_type = config.rope_parameters.get("rope_type", "default")
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
if self.rope_type == "default":
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
self.rope_kwargs = {}
else:
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
self.rope_kwargs = {}
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -1567,7 +1596,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+2 -2
View File
@@ -144,7 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -152,7 +152,7 @@ def preprocesser_call(
# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}
@@ -276,6 +276,8 @@ class Florence2LanguageConfig(PretrainedConfig):
)
# ensure backward compatibility for BART CNN models
if not hasattr(self, "forced_bos_token_id"):
self.forced_bos_token_id = None
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
@@ -1951,7 +1951,10 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
def __init__(self, config: Florence2LanguageConfig):
super().__init__(config)
@@ -2076,7 +2079,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"model.encoder.embed_tokens.weight": "model.shared.weight",
"model.decoder.embed_tokens.weight": "model.shared.weight",
}
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: Florence2LanguageConfig):
@@ -2436,11 +2442,10 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
_tied_weights_keys = [
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
"language_model.lm_head.weight",
]
_tied_weights_keys = {
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
}
def __init__(self, config: Florence2Config):
super().__init__(config)
+5 -5
View File
@@ -17,7 +17,7 @@
from __future__ import annotations
from enum import Enum
from typing import Any, TypeAlias, TypedDict
from typing import Any, TypedDict
import numpy as np
import torch
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
COMPLEMENTARY_DATA = "complementary_data"
PolicyAction: TypeAlias = torch.Tensor
RobotAction: TypeAlias = dict[str, Any]
EnvAction: TypeAlias = np.ndarray
RobotObservation: TypeAlias = dict[str, Any]
PolicyAction = torch.Tensor
RobotAction = dict[str, Any]
EnvAction = np.ndarray
RobotObservation = dict[str, Any]
EnvTransition = TypedDict(
+4 -4
View File
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
from typing import Any, TypedDict, TypeVar, cast
import torch
from huggingface_hub import hf_hub_download
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
@dataclass
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
class DataProcessorPipeline[TInput, TOutput](HubMixin):
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
This class chains together multiple `ProcessorStep` instances to form a complete
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
# Type aliases for semantic clarity.
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
class ObservationProcessorStep(ProcessorStep, ABC):
+1 -1
View File
@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
@@ -15,7 +15,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import TypeAlias
from lerobot.cameras import CameraConfig
@@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
pass
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
SO100FollowerConfig = SOFollowerRobotConfig
SO101FollowerConfig = SOFollowerRobotConfig
@@ -17,7 +17,6 @@
import logging
import time
from functools import cached_property
from typing import TypeAlias
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
@@ -230,5 +229,5 @@ class SOFollower(Robot):
logger.info(f"{self} disconnected.")
SO100Follower: TypeAlias = SOFollower
SO101Follower: TypeAlias = SOFollower
SO100Follower = SOFollower
SO101Follower = SOFollower
+1
View File
@@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
omx_leader,
openarm_leader,
openarm_mini,
so_leader,
unitree_g1,
)
+4 -1
View File
@@ -132,10 +132,13 @@ def visualize_dataset(
logging.info("Logging to Rerun")
first_index = None
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
if first_index is None:
first_index = batch["index"][0].item()
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time("frame_index", sequence=batch["frame_index"][i].item())
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
# display each camera image
+141 -60
View File
@@ -21,6 +21,9 @@ This script allows you to delete episodes, split datasets, merge datasets,
remove features, modify tasks, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
Path semantics (v2): --root and --new_root are exact dataset folders containing
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
Usage Examples:
Delete episodes 0, 2, and 5 from a dataset:
@@ -29,16 +32,31 @@ Delete episodes 0, 2, and 5 from a dataset:
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Delete episodes and save to a new dataset:
Delete episodes from a local dataset at a specific path:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_filtered \
--root /path/to/pusht \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Split dataset by fractions:
Delete episodes and save to a new dataset at a specific path and with a new repo_id:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_filtered \
--new_root /path/to/pusht_filtered \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Split dataset by fractions (pusht_train, pusht_val):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.8, "val": 0.2}'
Split dataset by fractions and save split datasets to a specific folder (base_folder/train, base_folder/val):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_root /path/to/base_folder \
--operation.type split \
--operation.splits '{"train": 0.8, "val": 0.2}'
@@ -56,15 +74,29 @@ Split into more than two splits:
Merge multiple datasets:
lerobot-edit-dataset \
--repo_id lerobot/pusht_merged \
--new_repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
Merge multiple datasets to a specific output path:
lerobot-edit-dataset \
--new_repo_id lerobot/pusht_merged \
--new_root /path/to/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
Merge multiple datasets from a list of local dataset paths:
lerobot-edit-dataset \
--new_repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['pusht_train', 'pusht_val']" \
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
Remove camera feature:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
--operation.feature_names "['observation.image']"
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
lerobot-edit-dataset \
@@ -88,8 +120,8 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
Convert image dataset to video format and save locally:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir /path/to/output/pusht_video
--new_root /path/to/output/pusht_video \
--operation.type convert_image_to_video
Convert image dataset to video format and save with new repo_id:
lerobot-edit-dataset \
@@ -167,6 +199,7 @@ class SplitConfig(OperationConfig):
@dataclass
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
roots: list[str] | None = None
@OperationConfig.register_subclass("remove_feature")
@@ -200,36 +233,46 @@ class ConvertImageToVideoConfig(OperationConfig):
@OperationConfig.register_subclass("info")
@dataclass
class InfoConfig(OperationConfig):
type: str = "info"
show_features: bool = False
@dataclass
class EditDatasetConfig:
repo_id: str
# Operation configuration.
operation: OperationConfig
# Input dataset identifier. Always required unless for Merge operation.
repo_id: str | None = None
# Root directory where the input dataset is stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
# Edited dataset identifier. When both new_repo_id (resp. new_root) and repo_id (resp. root) are identical, modifications are applied in-place and a backup of the original dataset is created. Required for Merge operation.
new_repo_id: str | None = None
# Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/new_repo_id. For Split operation, this is the base directory for the split datasets.
new_root: str | None = None
# Upload dataset to Hugging Face hub.
push_to_hub: bool = False
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
if new_repo_id:
output_repo_id = new_repo_id
output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
else:
output_repo_id = repo_id
dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
old_path = Path(str(dataset_path) + "_old")
def get_output_path(
repo_id: str,
new_repo_id: str | None,
root: Path | str | None,
new_root: Path | str | None,
) -> tuple[str, Path]:
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
if dataset_path.exists():
if old_path.exists():
shutil.rmtree(old_path)
shutil.move(str(dataset_path), str(old_path))
output_repo_id = new_repo_id if new_repo_id else repo_id
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
output_dir = dataset_path
# In case of in-place modification, create a backup of the original dataset (if it exists)
if output_path == input_path:
backup_path = input_path.with_name(input_path.name + "_old")
return output_repo_id, output_dir
if input_path.exists():
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(input_path, backup_path)
return output_repo_id, output_path
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
@@ -241,11 +284,15 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
cfg.repo_id,
new_repo_id=cfg.new_repo_id,
root=cfg.root,
new_root=cfg.new_root,
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
# In case of in-place modification, make the dataset point to the backup directory
if output_dir == dataset.root:
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
new_dataset = delete_episodes(
@@ -272,19 +319,27 @@ def handle_split(cfg: EditDatasetConfig) -> None:
"splits dict must be specified with split names as keys and fractions/episode lists as values"
)
if cfg.new_repo_id is not None:
logging.warning(
"split uses the original dataset identifier --repo_id to generate split names. The --new_repo_id parameter is ignored."
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
split_datasets = split_dataset(
dataset,
splits=cfg.operation.splits,
output_dir=cfg.new_root,
)
for split_name, split_ds in split_datasets.items():
split_repo_id = f"{cfg.repo_id}_{split_name}"
logging.info(
f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
)
if cfg.push_to_hub:
logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
logging.info(f"Pushing {split_name} split to hub as {split_ds.repo_id}")
LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
@@ -295,18 +350,29 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
if not cfg.operation.repo_ids:
raise ValueError("repo_ids must be specified for merge operation")
if not cfg.repo_id:
raise ValueError("repo_id must be specified as the output repository for merged dataset")
if cfg.repo_id is not None or cfg.root is not None:
logging.warning(
"merge uses --new_repo_id and --new_root for the merged dataset. The --repo_id and --root parameters are ignored."
)
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
if cfg.operation.roots:
if len(cfg.operation.roots) != len(cfg.operation.repo_ids):
raise ValueError("repo_ids and roots must have the same length for merge operation")
logging.info(f"Loading {len(cfg.operation.roots)} datasets to merge")
datasets = [
LeRobotDataset(repo_id=repo_id, root=root)
for repo_id, root in zip(cfg.operation.repo_ids, cfg.operation.roots, strict=True)
]
else:
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
datasets = [LeRobotDataset(repo_id) for repo_id in cfg.operation.repo_ids]
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
output_dir = Path(cfg.new_root) if cfg.new_root else HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Merging datasets into {cfg.repo_id}")
logging.info(f"Merging datasets into {cfg.new_repo_id}")
merged_dataset = merge_datasets(
datasets,
output_repo_id=cfg.repo_id,
output_repo_id=cfg.new_repo_id,
output_dir=output_dir,
)
@@ -316,7 +382,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
)
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {cfg.repo_id}")
logging.info(f"Pushing to hub as {cfg.new_repo_id}")
LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
@@ -329,11 +395,15 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
cfg.repo_id,
new_repo_id=cfg.new_repo_id,
root=cfg.root,
new_root=cfg.new_root,
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
# In case of in-place modification, make the dataset point to the backup directory
if output_dir == dataset.root:
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
new_dataset = remove_feature(
@@ -361,9 +431,10 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if new_task is None and episode_tasks_raw is None:
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
# Warn about in-place modification behavior
if cfg.new_repo_id is not None:
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
if cfg.new_repo_id is not None or cfg.new_root is not None:
logging.warning(
"modify_tasks modifies datasets in-place. The --new_repo_id and --new_root parameters are ignored."
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
@@ -399,32 +470,30 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
# Determine output directory and repo_id
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
# Priority: 1) new_root, 2) new_repo_id, 3) operation.output_dir, 4) auto-generated name
output_dir_config = getattr(cfg.operation, "output_dir", None)
if output_dir_config:
logging.warning(
"--operation.output_dir is deprecated and will be removed in future versions. "
"Please use --new_root instead."
)
if cfg.new_repo_id:
# Use new_repo_id for both local storage and hub push
if cfg.new_root:
output_dir = Path(cfg.new_root)
output_repo_id = cfg.new_repo_id or f"{cfg.repo_id}_video"
logging.info(f"Saving to new_root: {output_dir} as {output_repo_id}")
elif cfg.new_repo_id:
output_repo_id = cfg.new_repo_id
# Place new dataset as a sibling to the original dataset
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
# Extract just the dataset name (after last slash) for the local directory
local_dir_name = cfg.new_repo_id.split("/")[-1]
output_dir = dataset.root.parent / local_dir_name
output_dir = HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
elif output_dir_config:
# Use custom output directory for local-only storage
output_dir = Path(output_dir_config)
# Extract repo name from output_dir for the dataset
output_repo_id = output_dir.name
logging.info(f"Saving to local directory: {output_dir}")
logging.info(f"Saving to local directory: {output_dir} as {output_repo_id}")
else:
# Auto-generate name: append "_video" to original repo_id
output_repo_id = f"{cfg.repo_id}_video"
# Place new dataset as a sibling to the original dataset
# Extract just the dataset name (after last slash) for the local directory
local_dir_name = output_repo_id.split("/")[-1]
output_dir = dataset.root.parent / local_dir_name
logging.info(f"Saving to auto-generated location: {output_dir}")
output_dir = HF_LEROBOT_HOME / output_repo_id
logging.info(f"Saving to auto-generated location: {output_dir} as {output_repo_id}")
logging.info(f"Converting dataset {cfg.repo_id} to video format")
@@ -499,8 +568,20 @@ def handle_info(cfg: EditDatasetConfig):
sys.stdout.write(f"{feature_dump_str}\n")
def _validate_config(cfg: EditDatasetConfig) -> None:
if isinstance(cfg.operation, MergeConfig):
if not cfg.new_repo_id:
raise ValueError("--new_repo_id is required for merge operation (the merged dataset identifier)")
else:
if not cfg.repo_id:
raise ValueError(
f"--repo_id is required for {cfg.operation.type} operation (the input dataset identifier)"
)
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
_validate_config(cfg)
operation_type = cfg.operation.type
if operation_type == "delete_episodes":
@@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
omx_leader,
openarm_leader,
openarm_mini,
so_leader,
)
from lerobot.utils.robot_utils import precise_sleep
+30 -75
View File
@@ -74,8 +74,6 @@ from pathlib import Path
from pprint import pformat
from typing import Any
import torch
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
@@ -92,7 +90,6 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
@@ -128,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
omx_leader,
openarm_leader,
openarm_mini,
reachy2_teleoperator,
so_leader,
unitree_g1,
@@ -157,7 +155,7 @@ class DatasetRecordConfig:
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
@@ -228,9 +226,6 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@@ -303,7 +298,6 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -340,14 +334,7 @@ def record_loop(
preprocessor.reset()
postprocessor.reset()
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
no_action_count = 0
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -368,58 +355,24 @@ def record_loop(
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
# Get action keys from robot
action_keys = sorted(robot.action_features)
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
# Convert to tensor for interpolator
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
# Get interpolated action
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
# No action available yet, skip this iteration
continue
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
elif policy is None and isinstance(teleop, Teleoperator):
act = teleop.get_action()
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
act_processed_teleop = teleop_action_processor((act, obs))
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
arm_action = teleop_arm.get_action()
@@ -428,15 +381,23 @@ def record_loop(
base_action = robot._from_keyboard_to_base_action(keyboard_action)
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs))
else:
no_action_count += 1
if no_action_count == 1 or no_action_count % 10 == 0:
logging.warning(
"No policy or teleoperator provided, skipping action generation. "
"This is likely to happen when resetting the environment without a teleop device. "
"The robot won't be at its rest position at the start of the next episode."
)
continue
# Applies a pipeline to the action, default is IdentityProcessor
if policy is not None and act_processed_policy is not None:
action_values = act_processed_policy
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
else:
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
"This is likely to happen when resetting the environment without a teleop device."
"The robot won't be at its rest position at the start of the next episode."
)
continue
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
@@ -457,7 +418,7 @@ def record_loop(
dt_s = time.perf_counter() - start_loop_t
sleep_time_s: float = control_interval - dt_s
sleep_time_s: float = 1 / fps - dt_s
if sleep_time_s < 0:
logging.warning(
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
@@ -544,7 +505,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
@@ -555,10 +515,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
@@ -590,7 +546,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)
+1 -1
View File
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
@@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_mini,
so_leader,
)
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
"koch_leader",
"omx_follower",
"omx_leader",
"openarm_mini",
"so100_follower",
"so100_leader",
"so101_follower",
@@ -94,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config,
omx_leader,
openarm_leader,
openarm_mini,
reachy2_teleoperator,
so_leader,
unitree_g1,
+17 -2
View File
@@ -24,6 +24,7 @@ import torch
from accelerate import Accelerator
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
format_big_number,
has_method,
init_logging,
inside_slurm,
)
@@ -378,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
# Use effective batch size for proper epoch calculation in distributed training
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
effective_batch_size,
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
progbar = tqdm(
total=cfg.steps - step,
desc="Training",
unit="step",
disable=inside_slurm(),
position=0,
leave=True,
)
logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
)
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
accelerator.wait_for_everyone()
if is_main_process:
progbar.close()
if eval_env:
close_envs(eval_env)
@@ -306,7 +306,7 @@ def train_fast_tokenizer(
# download the tokenizer source code (not pretrained weights)
# we'll train a new tokenizer on our own data
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)
# convert action_chunks array to list of arrays (expected by .fit())
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,18 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
__all__ = [
"ActionInterpolator",
"ActionQueue",
"LatencyTracker",
"RTCConfig",
"RTCProcessor",
]
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -0,0 +1,296 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..teleoperator import Teleoperator
from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
class OpenArmMini(Teleoperator):
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
name = "openarm_mini"
def __init__(self, config: OpenArmMiniConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
"""
if self.calibration:
user_input = input(
f"Press ENTER to use existing calibration for {self.id}, "
f"or type 'c' and press ENTER to run new calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
range_min = int(closed_pos)
range_max = int(open_pos)
drive_mode = 0
else:
range_min = int(open_pos)
range_max = int(closed_pos)
drive_mode = 1
logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
range_min=range_min,
range_max=range_max,
)
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
action: dict[str, Any] = {}
for motor, val in right_positions.items():
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
@check_if_not_connected
def disconnect(self) -> None:
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
@@ -15,7 +15,6 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TypeAlias
from ..config import TeleoperatorConfig
@@ -38,5 +37,5 @@ class SOLeaderTeleopConfig(TeleoperatorConfig, SOLeaderConfig):
pass
SO100LeaderConfig: TypeAlias = SOLeaderTeleopConfig
SO101LeaderConfig: TypeAlias = SOLeaderTeleopConfig
SO100LeaderConfig = SOLeaderTeleopConfig
SO101LeaderConfig = SOLeaderTeleopConfig
@@ -16,7 +16,6 @@
import logging
import time
from typing import TypeAlias
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
@@ -156,5 +155,5 @@ class SOLeader(Teleoperator):
logger.info(f"{self} disconnected.")
SO100Leader: TypeAlias = SOLeader
SO101Leader: TypeAlias = SOLeader
SO100Leader = SOLeader
SO101Leader = SOLeader
+4
View File
@@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .bi_openarm_leader import BiOpenArmLeader
return BiOpenArmLeader(config)
elif config.type == "openarm_mini":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
+1 -1
View File
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy_cfg is None:
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
)
# Check if dataset_name does not start with "eval_" but policy is provided
+1 -3
View File
@@ -16,12 +16,10 @@
import json
import warnings
from pathlib import Path
from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)
def write_video(video_path, stacked_frames, fps):
@@ -33,7 +31,7 @@ def write_video(video_path, stacked_frames, fps):
imageio.mimsave(video_path, stacked_frames, fps=fps)
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T:
"""
Loads the JSON data from `fpath` and recursively fills `obj` with the
corresponding values (strictly matching structure and types).
+4 -2
View File
@@ -104,9 +104,10 @@ class MetricsTracker:
self.metrics = metrics
self.steps = initial_step
world_size = accelerator.num_processes if accelerator else 1
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
self.samples = self.steps * self._batch_size
self.samples = self.steps * self._batch_size * world_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
@@ -132,7 +133,8 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
world_size = self.accelerator.num_processes if self.accelerator else 1
self.samples += self._batch_size * world_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
size 47424
+1 -1
View File
@@ -222,7 +222,7 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
df = pd.DataFrame({"task_index": ids}, index=tasks)
df = pd.DataFrame({"task_index": ids}, index=pd.Index(tasks, name="task"))
return df
return _create_tasks
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -37,6 +38,9 @@ def test_classifier_output():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -78,6 +82,9 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -117,6 +124,9 @@ def test_multiclass_classifier():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -129,6 +139,9 @@ def test_default_device():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -17,7 +17,6 @@
"""Test script to verify PI0Fast policy integration with LeRobot vs the original implementation"""
# ruff: noqa: E402
import os
import random
from copy import deepcopy
from typing import Any
@@ -28,10 +27,6 @@ import torch
pytest.importorskip("transformers")
pytest.importorskip("scipy")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires accepting the model license",
)
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
@@ -53,22 +48,23 @@ DUMMY_STATE_DIM = 20
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cuda"
MODEL_PATH_LEROBOT = "lerobot/pi0fast-base"
# Expected action token shape: (batch_size, max_decoding_steps)
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
# Expected first 5 action tokens (for reproducibility check)
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362])
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255020, 255589])
# Expected actions after detokenization
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
EXPECTED_ACTIONS_MEAN = 0.04419417306780815
EXPECTED_ACTIONS_STD = 0.26231569051742554
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000])
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
EXPECTED_ACTIONS_STD = 0.2607129216194153
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000])
@require_cuda
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
random.seed(seed)
@@ -85,6 +81,7 @@ def set_seed_all(seed: int):
torch.use_deterministic_algorithms(True, warn_only=True)
@require_cuda
def instantiate_lerobot_pi0_fast(
from_pretrained: bool = False,
model_path: str = MODEL_PATH_LEROBOT,
@@ -127,6 +124,7 @@ def instantiate_lerobot_pi0_fast(
return policy, preprocessor, postprocessor
@require_cuda
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 1
@@ -158,22 +156,25 @@ def create_dummy_data(device=DEVICE):
# Pytest fixtures
@pytest.fixture(scope="module")
@require_cuda
def pi0_fast_components():
"""Fixture to instantiate and provide all PI0Fast components for tests."""
print(f"\nTesting with DEVICE='{DEVICE}'")
print("\n[Setup] Instantiating LeRobot PI0Fast policy...")
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_pi0_fast(from_pretrained=True)
print("Model loaded successfully")
yield policy_obj, preprocessor_obj, postprocessor_obj
return policy_obj, preprocessor_obj, postprocessor_obj
@pytest.fixture(scope="module")
@require_cuda
def policy(pi0_fast_components):
"""Fixture to provide the PI0Fast policy for tests."""
return pi0_fast_components[0]
@pytest.fixture(scope="module")
@require_cuda
def preprocessor(pi0_fast_components):
"""Fixture to provide the PI0Fast preprocessor for tests."""
return pi0_fast_components[1]
-9
View File
@@ -16,17 +16,8 @@
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0 import ( # noqa: E402
PI0Config,
+1 -11
View File
@@ -16,25 +16,15 @@
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
import os
import pytest
import torch
from lerobot.utils.random_utils import set_seed
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi05 import ( # noqa: E402
PI05Config,
PI05Policy,
make_pi05_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed
from tests.utils import require_cuda # noqa: E402
+2 -1
View File
@@ -24,9 +24,10 @@ import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
reason="TODO: This test seems to hang the CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
+3 -1
View File
@@ -24,9 +24,10 @@ import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
reason="TODO: This test seems to hang the CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
@@ -88,6 +89,7 @@ def test_pi0_rtc_initialization_without_rtc_config():
print("✓ PI0 RTC initialization without RTC config: Test passed")
@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
"""Test PI0 policy inference with RTC and previous chunk."""
set_seed(42)

Some files were not shown because too many files have changed in this diff Show More