Compare commits

..

2 Commits

Author SHA1 Message Date
Jade Choghari 435f12f6e4 pre-commit 2026-02-10 13:07:40 +01:00
Stepan Feduniak a1193df2d7 fix the types 2026-01-31 09:56:15 +00:00
184 changed files with 1267 additions and 7980 deletions
+1 -8
View File
@@ -44,7 +44,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.12"
PYTHON_VERSION: "3.10"
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -61,7 +61,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -90,11 +89,5 @@ jobs:
- name: Install lerobot with test extras
run: uv sync --extra "test"
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest
run: uv run pytest tests -vv --maxfail=10
+4 -21
View File
@@ -37,7 +37,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.12"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest action is built, canceling older runs.
@@ -60,7 +60,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -88,12 +87,6 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
@@ -108,11 +101,9 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:
@@ -169,7 +160,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -181,13 +171,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
+4 -20
View File
@@ -28,7 +28,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.12"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
@@ -119,7 +119,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
@@ -131,11 +130,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on CPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -152,7 +146,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -164,11 +157,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -186,7 +174,6 @@ jobs:
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -198,15 +185,12 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
run: pytest -vv tests/training/
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
+1 -1
View File
@@ -50,7 +50,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
python-version: '3.10'
- name: Run pre-commit hooks
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
+10 -2
View File
@@ -22,7 +22,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.12"
PYTHON_VERSION: "3.10"
jobs:
# This job builds the Python package and publishes it to PyPI
@@ -45,7 +45,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
python-version: '3.10'
- name: Extract Version
id: extract_info
@@ -83,6 +83,14 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
+2 -14
View File
@@ -29,7 +29,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.12"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
# Ensures that only the latest action is built, canceling older runs.
@@ -48,7 +48,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -80,11 +79,7 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -96,7 +91,6 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:
@@ -142,7 +136,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -154,11 +147,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv
- name: Run end-to-end tests
+2 -2
View File
@@ -13,7 +13,7 @@
# limitations under the License.
default_language_version:
python: python3.12
python: python3.10
exclude: "tests/artifacts/.*\\.safetensors$"
@@ -55,7 +55,7 @@ repos:
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py312-plus]
args: [--py310-plus]
##### Markdown Quality #####
- repo: https://github.com/rbubley/mirrors-prettier
-25
View File
@@ -1,25 +0,0 @@
# AI Usage Policy
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
## Remember the Human Maintainers
Please remember that LeRobot is maintained by a dedicated team of humans.
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
## AI is Welcome Here
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+1 -1
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
## Ways to Contribute
-1
View File
@@ -1,3 +1,2 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
+5 -5
View File
@@ -100,11 +100,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+42 -42
View File
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
aliberts/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
aliberts/aloha_mobile_shrimp_image \
aliberts/paris_street \
aliberts/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_size_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
| | | vcodec | pix_fmt | | | |
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
+1 -3
View File
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
# Define Python version argument
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.10
# Configure environment variables
ENV DEBIAN_FRONTEND=noninteractive \
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+1 -1
View File
@@ -19,7 +19,7 @@
# docker run -it --rm lerobot-user
# Configure the base image
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.10
FROM python:${PYTHON_VERSION}-slim
# Configure environment variables
-4
View File
@@ -29,8 +29,6 @@
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
- sections:
- local: act
@@ -47,8 +45,6 @@
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
title: Multitask DiT Policy
- local: walloss
title: WALL-OSS
title: "Policies"
-3
View File
@@ -88,8 +88,5 @@ lerobot-record \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
--dataset.num_episodes=10 \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```
+1 -1
View File
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+4 -4
View File
@@ -32,7 +32,7 @@ version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
requires-python = ">= 3.11"
[build-system]
build-backend = # your-build-backend
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Any
from typing import Dict, Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
@@ -102,7 +102,7 @@ Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Any
from typing import Dict, Any
import torch
+4 -7
View File
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
### Hardware
- EarthRover Mini robot
- Computer with Python 3.12 or newer
- Computer with Python 3.10 or newer
- Internet connection
### Setting Up the Frodobots SDK
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
@@ -185,16 +185,13 @@ echo $HF_USER
Use the standard recording command:
```bash
lerobot-record \
python src/lerobot/scripts/lerobot_record.py \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \
--dataset.num_episodes=2 \
--dataset.fps=10 \
--dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
+2 -2
View File
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
pip install huggingface_hub
# Login to Hugging Face
hf auth login
huggingface-cli login
# Create a new repository
hf repo create my-org/my-custom-env
huggingface-cli repo create my-custom-env --type space --org my-org
# Initialize git and push
git init
+3 -6
View File
@@ -120,12 +120,9 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.single_task="Grab and handover the red cube to the other arm"
--policy.path=<user>/groot-bimanual # your trained model
--dataset.episode_time_s=30
--dataset.reset_time_s=10
```
+5 -11
View File
@@ -224,15 +224,12 @@ lerobot-record \
--teleop.port=/dev/tty.usbmodem1201 \
--teleop.id=right \
--teleop.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--dataset.single_task="Hand recording test with video data" \
--dataset.num_episodes=1 \
--dataset.episode_time_s=5 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -244,7 +241,7 @@ lerobot-replay \
--robot.port=/dev/tty.usbmodem58760432281 \
--robot.id=right \
--robot.side=right \
--dataset.repo_id=<USER>/hand_record_test_with_camera \
--dataset.repo_id=nepyope/hand_record_test_with_camera \
--dataset.episode=0
```
@@ -252,13 +249,13 @@ lerobot-replay \
```bash
lerobot-train \
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
--policy.type=act \
--output_dir=outputs/train/hopejr_hand \
--job_name=hopejr \
--policy.device=mps \
--wandb.enable=true \
--policy.repo_id=<USER>/hand_test_policy
--policy.repo_id=nepyope/hand_test_policy
```
### Evaluate
@@ -273,11 +270,8 @@ lerobot-record \
--robot.side=right \
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
--display_data=false \
--dataset.repo_id=<USER>/eval_hopejr \
--dataset.repo_id=nepyope/eval_hopejr \
--dataset.single_task="Evaluate hopejr hand policy" \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```
+6 -12
View File
@@ -159,13 +159,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(hf auth whoami | head -n 1)
echo $HF_USER
```
@@ -185,10 +185,7 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=${HF_USER}/record-test \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--dataset.single_task="Grab the black cube"
```
</hfoption>
<hfoption id="API example">
@@ -327,7 +324,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
You can also push your local dataset to the Hub manually, running:
```bash
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
```
#### Record function
@@ -491,7 +488,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
Once training is done, upload the latest checkpoint with:
```bash
hf upload ${HF_USER}/act_so101_test \
huggingface-cli upload ${HF_USER}/act_so101_test \
outputs/train/act_so101_test/checkpoints/last/pretrained_model
```
@@ -499,7 +496,7 @@ You can also upload intermediate checkpoints with:
```bash
CKPT=010000
hf upload ${HF_USER}/act_so101_test${CKPT} \
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
```
@@ -518,9 +515,6 @@ lerobot-record \
--display_data=false \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
+8 -14
View File
@@ -1,20 +1,18 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Step 2: Environment Setup
## Environment Setup
Create a virtual environment with Python 3.12, using conda:
Create a virtual environment with Python 3.10, using conda:
```bash
conda create -y -n lerobot python=3.12
conda create -y -n lerobot python=3.10
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
@@ -40,14 +38,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
> [!NOTE]
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command:
>
> ```bash
> conda install evdev -c conda-forge
> ```
## Step 3: Install LeRobot 🤗
## Install LeRobot 🤗
### From Source
@@ -90,6 +81,9 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
+2 -2
View File
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
+1 -4
View File
@@ -41,10 +41,7 @@ lerobot-record \
--display_data=true \
--dataset.repo_id=${HF_USER}/record-test \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--dataset.single_task="Grab the black cube"
```
See the [recording guide](./il_robots#record-a-dataset) for more details.
-340
View File
@@ -1,340 +0,0 @@
# Multitask DiT Policy
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
## Model Overview
The model uses:
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
## Installation Requirements
Multitask DiT Policy has additional dependencies. Install it with:
```bash
pip install lerobot[multi_task_dit]
```
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
## Usage
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
```python
policy.type=multi_task_dit
```
## Training
### Basic Training Command
Here's a complete training command for training Multitask DiT on your dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/multitask_dit_training \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
For reliable performance, start with these suggested default hyperparameters:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true
```
**Key Parameters:**
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task
### Training Configuration Parameters
#### Objective Selection
Choose between diffusion and flow matching:
```bash
# Diffusion objective (default)
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10 \ # For faster inference
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
--policy.clip_sample=true \ # Clip samples during denoising
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
# Flow matching objective
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
--policy.num_integration_steps=100 \
--policy.integration_method=euler \ # or "rk4"
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
```
#### Transformer Architecture
Adjust model capacity based on dataset size:
```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512 \
--policy.num_heads=8 # should ideally be hidden_dim // 64
```
**Positional Encoding Options:**
The model supports two positional encoding methods for action sequences:
```bash
# Rotary Position Embedding (RoPE) - default, recommended
--policy.use_rope=true \
--policy.rope_base=10000.0 # Base frequency for RoPE
# Absolute positional encoding
--policy.use_positional_encoding=true # Disables RoPE when true
```
**Other Transformer Parameters:**
```bash
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
--policy.timestep_embed_dim=256 # Timestep embedding dimension
```
#### Vision Encoder Configuration
```bash
# Use different CLIP model for more expressivity at the cost of inference time
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
--policy.vision_encoder_name=openai/clip-vit-large-patch14
# Use separate vision encoder per camera
# This may be useful when cameras have significantly different characteristics, but
# be wary of increased VRAM footprint.
--policy.use_separate_rgb_encoder_per_camera=true
# Image preprocessing
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
#### Text Encoder Configuration
```bash
# Use different CLIP text encoder model
# same as vision: experiment with larger or smaller models depending on the
# complexity of your tasks and size of dataset
--policy.text_encoder_name=openai/clip-vit-large-patch14
```
#### Learning Rate Configuration
The vision encoder uses a separate learning rate multiplier, where 1/10th is suggested to be the ideal staritng point:
```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
### Training Tuning Guidelines
#### 1. Flow Matching with Beta Sampling
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331)
Additionally, we have implemented a flow-matching objective, which is described at a high-level in [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
Consider testing the flow-matching objective and evaluating performance differences for your task:
```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```
This hasn't been shown to be a silver bullet across every user case, but it occasionally results in smoother and more consistent actions.
#### 2. Number of Transformer Layers
Match model capacity to your dataset size:
- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers
#### 3. `horizon` Tuning
The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:
- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
#### 4. `n_action_steps` Sensitivity
The model can also be very sensitive to `n_action_steps`. Start with it being around 0.8 seconds based on your control frequency and tune from there:
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution but open-loop failures are limited in their recovery
### Inference Tuning
For faster inference, use DDIM with fewer sampling steps:
```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```
### Resuming Training
To resume training from a checkpoint:
```bash
lerobot-train \
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
## Common Failure Modes and Debugging
Training these models can be finicky. Here are common failure modes and debugging approaches:
### Idling / No Motion
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
**Debugging tips:**
- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check if the model is receiving proper language instructions or increase diversity of instruction
### Executing the Wrong Task
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
**Potential causes:**
- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset
**Debugging tips:**
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check task distribution in your training dataset and add weighting to the failing/ignored task
- Consider task-specific fine-tuning
### Training Instability
If training loss is unstable or diverging:
- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly
## Performance Considerations
### GPU Requirements
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable speed performance
- **Training**: A GPU with enough VRAM to load batch sizes of >64 is ideal, which will vary depending on the number of image observations, etc
### Batch Size Recommendations
- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)
## Example: Training on Custom Dataset
Here's a complete example training on a custom dataset:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/mutitask_dit_training \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
--wandb.enable=true \
--wandb.project=multitask_dit
```
## References
For more details on the technical implementation and architecture, see:
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
+5 -9
View File
@@ -66,13 +66,12 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase
All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`.
Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
- Run this example to teleoperate:
```bash
cd examples/phone_to_so100
python teleoperate.py
python examples/phone_to_so100/teleoperate.py
```
After running the example:
@@ -85,22 +84,19 @@ Additionally you can customize mapping or safety limits by editing the processor
- Run this example to record a dataset, which saves absolute end effector observations and actions:
```bash
cd examples/phone_to_so100
python record.py
python examples/phone_to_so100/record.py
```
- Run this example to replay recorded episodes:
```bash
cd examples/phone_to_so100
python replay.py
python examples/phone_to_so100/replay.py
```
- Run this example to evaluate a pretrained policy:
```bash
cd examples/phone_to_so100
python evaluate.py
python examples/phone_to_so100/evaluate.py
```
### Important pipeline steps and options
+6 -1
View File
@@ -34,6 +34,11 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
@@ -55,7 +60,7 @@ policy.type=pi0
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
+6 -1
View File
@@ -36,6 +36,11 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
@@ -51,7 +56,7 @@ policy.type=pi05
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py\
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
+15 -10
View File
@@ -43,11 +43,16 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training a Custom FAST Tokenizer
You have two options for the FAST tokenizer:
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
@@ -109,15 +114,15 @@ lerobot-train \
### Key Training Parameters
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
## Inference
@@ -1,37 +0,0 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```
-6
View File
@@ -159,9 +159,6 @@ lerobot-record \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
@@ -201,9 +198,6 @@ lerobot-record \
--dataset.fps=15 \
--dataset.push_to_hub=true \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
--display_data=true
```
+4 -4
View File
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
Train with **no annotations** - uses linear progress from 0 to 1:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=single_stage \
@@ -288,7 +288,7 @@ lerobot-train \
Train with **dense annotations only** (sparse auto-generated):
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dense_only \
@@ -307,7 +307,7 @@ lerobot-train \
Train with **both sparse and dense annotations**:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=sarm \
--policy.annotation_mode=dual \
@@ -468,7 +468,7 @@ This script:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
-3
View File
@@ -106,9 +106,6 @@ lerobot-record \
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
--dataset.episode_time_s=50 \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
-155
View File
@@ -1,155 +0,0 @@
# Streaming Video Encoding Guide
## 1. Overview
Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of:
1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's
Frames can be encoded in real-time during capture:
1. Capture frame -> queue to encoder thread -> encode to MP4 directly
This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes.
## 2. Tuning Parameters
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
## 3. Performance Considerations
Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between:
- **Control loop** (reading cameras, control the robot, writing non-video data)
- **Encoder threads** (one pool per camera)
- **Rerun visualization** (if enabled)
- **OS and other processes**
### Resolution & Number of Cameras Impact
| Setup | Throughput (px/sec) | CPU Encoding Load | Notes |
| ------------------------- | ------------------- | ----------------- | ------------------------------ |
| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems |
| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems |
| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU |
### `encoder_threads` Tuning
This parameter controls how many threads each encoder instance uses internally:
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
- **`None` (default)**: Lets the codec decide. Information available in the codec logs.
### Backpressure and Frame Dropping
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
1. The queue fills up (consuming RAM)
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"`
4. At episode end, total dropped frames per camera are reported
### Symptoms of Encoder Falling Behind
- **System feels laggy and freezes**: all CPUs are at 100%
- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset
- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected
- **Accumulated rerun lag**: Visualization falls behind real-time
## 4. Hardware-Accelerated Encoding
### When to Use
Use HW encoding when:
- CPU is the bottleneck (dropped frames, choppy robot, rerun lag)
- You have compatible hardware (GPU or dedicated encoder)
- You're recording at high throughput (high resolution or with many cameras)
### Choosing a Codec
| Codec | CPU Usage | File Size | Quality | Notes |
| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- |
| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive |
| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU |
| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems |
### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
> [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
> [!NOTE]
> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time.
## 5. Troubleshooting
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations
These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually.
### High-End Systems: modern 12+ cores (24+ threads)
A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available.
```bash
# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
lerobot-record --dataset.encoder_threads=5 ...
# 3camsx 1920x1080x3 @30fps: Might require some tuning.
```
### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon
A throughput between ~80-300M px/sec should be possible in CPU.
```bash
# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
lerobot-record --dataset.encoder_threads=2 ...
# 2camsx 1920x1080x3 @30fps: Might require some tuning.
```
### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5
On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding.
```bash
# 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
```
## 7. Closing note
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
+6 -12
View File
@@ -123,7 +123,7 @@ SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.12
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
@@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.12
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
@@ -216,7 +216,7 @@ lerobot-teleoperate \
### Record Dataset in Simulation
```bash
lerobot-record \
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -229,10 +229,7 @@ lerobot-record \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--dataset.push_to_hub=true
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
@@ -269,7 +266,7 @@ lerobot-teleoperate \
### Record Dataset on Real Robot
```bash
lerobot-record \
python -m lerobot.scripts.lerobot_record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
@@ -282,10 +279,7 @@ lerobot-record \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--dataset.push_to_hub=true
```
**Note**: Update `server_address` to match your robot's camera server IP.
-25
View File
@@ -12,7 +12,6 @@ LeRobot provides several utilities for manipulating datasets:
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
@@ -157,30 +156,6 @@ lerobot-edit-dataset \
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Show the information of datasets
Show the information of datasets such as number of episode, number of frame, File size and so on.
No change will be made to the dataset
```bash
# Show dataset information without feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
# Show dataset information with feature details
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type info \
--operation.show_features true
```
**Parameters:**
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
### Push to Hub
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
+1 -1
View File
@@ -45,7 +45,7 @@ policy.type=wall_x
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
```bash
lerobot-train \
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=wall_x \
--output_dir=./outputs/wallx_training \
+1 -1
View File
@@ -154,7 +154,7 @@ lerobot-train \
```bash
lerobot-train \
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
+2 -2
View File
@@ -22,7 +22,7 @@ lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=<USER>/record-test \
--dataset.repo_id=aliberts/record-test \
--dataset.episode=2
```
"""
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
-490
View File
@@ -1,490 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SLURM-distributed SARM RA-BC annotation pipeline.
Computes SARM progress values for all frames in a dataset, distributed across
SLURM workers, then merges the shards into a single sarm_progress.parquet.
Two subcommands, each a separate SLURM submission:
compute N workers, each computes progress for a subset of episodes
aggregate 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
Usage:
python slurm_compute_rabc.py compute \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--stride 10 --device cpu --workers 50 --partition cpu
python slurm_compute_rabc.py aggregate \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--partition cpu --push-to-hub
"""
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
class ComputeProgressShards(PipelineStep):
"""Each worker computes SARM progress for its assigned episodes."""
def __init__(
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
):
super().__init__()
if stride < 1:
raise ValueError(f"stride must be >= 1, got {stride}")
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.stride = stride
self.head_mode = head_mode
self.device = device
self.shard_dir = shard_dir
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
)
from lerobot.utils.utils import init_logging
init_logging()
dataset, reward_model, preprocess = load_sarm_resources(
self.repo_id,
self.reward_model_path,
self.device,
)
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
frame_gap = reward_model.config.frame_gap
center_idx = reward_model.config.n_obs_steps // 2
dual_mode = reward_model.config.uses_dual_heads
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
compute_dense = self.head_mode in ("dense", "both") and dual_mode
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
if not my_episodes:
logging.info(f"Rank {rank}: no episodes assigned")
return
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
all_rows = []
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
ep = dataset.meta.episodes[ep_idx]
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
task = dataset[ep_start].get("task", "perform the task")
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
if self.stride > 1:
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
if (ep_end - 1) not in compute_indices:
compute_indices.append(ep_end - 1)
compute_indices = sorted(set(compute_indices))
else:
compute_indices = all_ep_indices
frame_results = {}
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
try:
sample = dataset[qi]
batch = {
image_key: sample[image_key],
"task": task,
"index": qi,
"episode_index": ep_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
vf = processed["video_features"].to(self.device)
tf = processed["text_features"].to(self.device)
sf = processed.get("state_features")
if sf is not None:
sf = sf.to(self.device)
lengths = processed.get("lengths")
sparse_val = dense_val = np.nan
if compute_sparse:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="sparse",
)
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
if compute_dense:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="dense",
)
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
frame_results[qi] = (sparse_val, dense_val)
except Exception as e:
logging.warning(f"Failed frame {qi}: {e}")
if not frame_results:
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
continue
# Interpolate to all frames in this episode
computed_idx = np.array(sorted(frame_results.keys()))
all_frame_arr = np.arange(ep_start, ep_end)
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
if self.stride > 1 and len(computed_idx) > 1:
if compute_sparse:
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
if compute_dense:
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
output_frames = all_frame_arr
else:
# Use only successfully computed frames to avoid indexing mismatch on failures
output_frames = computed_idx
for i, fi in enumerate(output_frames):
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
if compute_sparse:
row["progress_sparse"] = float(sparse_vals[i])
if compute_dense:
row["progress_dense"] = float(dense_vals[i])
all_rows.append(row)
if all_rows:
import pandas as pd
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
shard_dir = Path(self.shard_dir)
shard_dir.mkdir(parents=True, exist_ok=True)
out = shard_dir / f"shard_{rank:05d}.parquet"
pq.write_table(table, out)
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
class AggregateProgress(PipelineStep):
"""Merge all shard parquets into final sarm_progress.parquet."""
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
super().__init__()
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.shard_dir = shard_dir
self.push_to_hub = push_to_hub
def run(self, data=None, rank: int = 0, world_size: int = 1):
import datetime
import logging
import os
from pathlib import Path
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
init_logging()
if rank != 0:
return
shard_dir = Path(self.shard_dir)
shards = sorted(shard_dir.glob("shard_*.parquet"))
if not shards:
raise FileNotFoundError(f"No shards found in {shard_dir}")
# Log shard modification time range to help detect stale files
mtimes = [os.path.getmtime(s) for s in shards]
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
df = df.sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
out_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out_path)
logging.info(f"Saved {len(df)} rows to {out_path}")
for col in ["progress_sparse", "progress_dense"]:
if col in df.columns:
v = df[col].dropna()
logging.info(
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
)
if self.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = "sarm_progress.parquet"
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(out_path),
path_in_repo=hub_path,
repo_id=self.repo_id,
repo_type="dataset",
)
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
def make_compute_executor(
repo_id,
reward_model_path,
stride,
head_mode,
device,
shard_dir,
logs_dir,
job_name,
slurm,
workers,
partition,
cpus_per_task,
mem_per_cpu,
):
kwargs = {
"pipeline": [
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": workers,
"workers": workers,
"time": "24:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": workers, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def make_aggregate_executor(
repo_id,
reward_model_path,
shard_dir,
logs_dir,
job_name,
slurm,
partition,
cpus_per_task,
mem_per_cpu,
push_to_hub,
):
kwargs = {
"pipeline": [
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": 1,
"workers": 1,
"time": "02:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": 1, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def _add_shared_args(p):
p.add_argument(
"--repo-id",
type=str,
required=True,
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
)
p.add_argument(
"--shard-dir",
type=Path,
default=Path("rabc_shards"),
help="Directory to read/write per-rank parquet shards.",
)
p.add_argument(
"--logs-dir",
type=Path,
default=Path("logs"),
help="Directory for datatrove logs.",
)
p.add_argument(
"--job-name",
type=str,
default=None,
help="SLURM job name (defaults to rabc_<subcommand>).",
)
p.add_argument(
"--slurm",
type=int,
default=1,
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
)
p.add_argument(
"--partition",
type=str,
default=None,
help="SLURM partition to submit to.",
)
p.add_argument(
"--cpus-per-task",
type=int,
default=4,
help="Number of CPUs per SLURM task.",
)
p.add_argument(
"--mem-per-cpu",
type=str,
default="4G",
help="Memory per CPU, e.g. '4G' or '1950M'.",
)
def main():
parser = argparse.ArgumentParser(
description="SLURM-distributed SARM RA-BC annotation pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
sub = parser.add_subparsers(dest="command", required=True)
# compute subcommand
cp = sub.add_parser(
"compute",
help="Distribute progress computation across SLURM workers.",
)
_add_shared_args(cp)
cp.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model.",
)
cp.add_argument(
"--stride",
type=int,
default=1,
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
)
cp.add_argument(
"--head-mode",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="Which reward head(s) to compute.",
)
cp.add_argument(
"--device",
type=str,
default="cpu",
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
)
cp.add_argument(
"--workers",
type=int,
default=50,
help="Number of parallel SLURM tasks (one shard per worker).",
)
# aggregate subcommand
ap = sub.add_parser(
"aggregate",
help="Merge per-rank shards into a single sarm_progress.parquet.",
)
_add_shared_args(ap)
ap.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
)
ap.add_argument(
"--push-to-hub",
action="store_true",
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
)
args = parser.parse_args()
job_name = args.job_name or f"rabc_{args.command}"
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
kwargs["job_name"] = job_name
command = kwargs.pop("command")
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
executor.run()
if __name__ == "__main__":
main()
+10 -10
View File
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
@@ -58,16 +58,16 @@ Usage:
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/reuben_pi0 \
--dataset.repo_id=<USER>/so101_cube_in_cup \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
@@ -75,8 +75,8 @@ Usage:
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
@@ -84,8 +84,8 @@ Usage:
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--dataset.repo_id=<USER>/check_rtc \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
+3 -3
View File
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
@@ -41,7 +41,7 @@ Usage:
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
@@ -53,7 +53,7 @@ Usage:
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
+122 -59
View File
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.5"
version = "0.4.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
requires-python = ">=3.12"
requires-python = ">=3.10"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,8 +50,7 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.10",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -60,30 +59,28 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<5.0.0",
"datasets>=4.0.0,<4.2.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub>=1.0.0,<2.0.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"packaging>=24.2,<26.0",
"torch>=2.2.1,<2.11.0",
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"pynput>=1.7.8,<1.9.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"draccus==0.10.0", # TODO: Relax version constraint
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
@@ -98,20 +95,14 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]
damiao = ["python-can>=4.2.0,<5.0.0"]
# Robots
openarms = ["lerobot[damiao]"]
@@ -123,31 +114,30 @@ unitree_g1 = [
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"lerobot[matplotlib-dep]",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
wallx = [
"lerobot[transformers-dep]",
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"lerobot[peft]",
"peft>=0.13.0,<1.0.0",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -156,13 +146,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -170,19 +160,13 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
# All
all = [
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
# helps pip's resolver converge by constraining scipy early, before it encounters
# the loose scipy requirements from transitive deps like dm-control and metaworld.
"scipy>=1.14.0,<2.0.0",
"lerobot[dynamixel]",
"lerobot[gamepad]",
"lerobot[hopejr]",
@@ -190,8 +174,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[wallx]",
"lerobot[pi]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -203,7 +187,7 @@ all = [
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[phone]",
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[libero]",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[peft]",
@@ -228,14 +212,11 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
[tool.ruff]
target-version = "py312"
target-version = "py310"
line-length = 110
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
@@ -327,7 +308,7 @@ default.extend-ignore-identifiers-re = [
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
[tool.mypy]
python_version = "3.12"
python_version = "3.10"
ignore_missing_imports = true
follow_imports = "skip"
# warn_return_any = true
@@ -379,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"
@@ -411,3 +392,85 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "peft" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]
+3 -5
View File
@@ -63,9 +63,9 @@ from lerobot.transport import (
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
@@ -485,9 +485,8 @@ class RobotClient:
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
# TODO: Assert if checking robot support is still needed with the plugin system
# if cfg.robot.type not in SUPPORTED_ROBOTS:
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
@@ -513,5 +512,4 @@ def async_client(cfg: RobotClientConfig):
if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client
+1 -1
View File
@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .utils import make_cameras_from_configs
+1 -1
View File
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
"""
pass
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
-23
View File
@@ -25,10 +25,6 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -36,25 +32,6 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
+15 -10
View File
@@ -32,11 +32,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_rotation
from ..utils import get_cv2_backend, get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -118,7 +117,7 @@ class OpenCVCamera(Camera):
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = config.backend
self.backend: int = get_cv2_backend()
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -133,7 +132,6 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -150,6 +148,8 @@ class OpenCVCamera(Camera):
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -178,7 +178,6 @@ class OpenCVCamera(Camera):
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -198,6 +197,8 @@ class OpenCVCamera(Camera):
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -347,7 +348,6 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -374,6 +374,9 @@ class OpenCVCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -487,7 +490,6 @@ class OpenCVCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -510,6 +512,8 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,8 +533,7 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -545,6 +548,8 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@@ -50,7 +50,6 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -63,12 +62,22 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
@@ -74,4 +74,7 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
@@ -32,7 +32,6 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -124,7 +123,6 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -138,6 +136,9 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -183,7 +184,6 @@ class Reachy2Camera(Camera):
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Same as read()
@@ -197,11 +197,12 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -218,6 +219,8 @@ class Reachy2Camera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
@@ -230,7 +233,6 @@ class Reachy2Camera(Camera):
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -238,6 +240,8 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()
@@ -30,8 +30,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -153,7 +152,6 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -171,6 +169,8 @@ class RealSenseCamera(Camera):
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -290,7 +290,6 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -300,6 +299,8 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -319,7 +320,6 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -345,6 +345,9 @@ class RealSenseCamera(Camera):
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -371,7 +374,6 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -401,6 +403,9 @@ class RealSenseCamera(Camera):
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,7 +534,6 @@ class RealSenseCamera(Camera):
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -552,6 +556,8 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -572,8 +578,7 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
This method is non-blocking and returns whatever is currently in the
@@ -588,6 +593,8 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -60,8 +60,20 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):
+12
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -67,3 +68,14 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
+10 -6
View File
@@ -34,8 +34,7 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -105,7 +104,6 @@ class ZMQCamera(Camera):
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server.
@@ -113,6 +111,8 @@ class ZMQCamera(Camera):
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
@@ -211,7 +211,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -229,6 +228,9 @@ class ZMQCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -299,7 +301,6 @@ class ZMQCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -316,6 +317,8 @@ class ZMQCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -332,7 +335,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -348,6 +350,8 @@ class ZMQCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
+4 -1
View File
@@ -32,7 +32,10 @@ class ZMQCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
+1 -1
View File
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+6 -6
View File
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
+1 -3
View File
@@ -289,9 +289,7 @@ def aggregate_datasets(
logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame(
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
)
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
-7
View File
@@ -7,13 +7,6 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+33 -41
View File
@@ -89,8 +89,8 @@ def delete_episodes(
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -152,7 +152,7 @@ def split_dataset(
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
output_dir: Base directory for output datasets. If None, uses default location.
Examples:
Split by specific episodes
@@ -243,8 +243,8 @@ def merge_datasets(
Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +288,8 @@ def modify_features(
dataset: The source LeRobotDataset.
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with features modified.
@@ -390,8 +390,8 @@ def add_features(
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with all features added.
@@ -427,8 +427,8 @@ def remove_feature(
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
Returns:
New dataset with features removed.
@@ -567,22 +567,20 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
episodes_to_keep: list[tuple[float, float]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified frame ranges and re-encodes them with
This function decodes frames from specified time ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -624,10 +622,9 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
frame_ranges = sorted(episodes_to_keep)
time_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -637,20 +634,21 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Check if frame is in any of our desired frame ranges.
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Skip ranges that have already passed.
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(frame_ranges):
if range_idx >= len(time_ranges):
break
# Check if frame is in current range.
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -663,7 +661,6 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -752,17 +749,15 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of frame ranges to keep, in sorted order.
# Build list of time ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[int, int]] = []
episodes_to_keep_ranges: list[tuple[float, float]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -1475,9 +1470,7 @@ def modify_tasks(
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame(
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
)
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
@@ -1531,7 +1524,7 @@ def modify_tasks(
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -1550,8 +1543,8 @@ def convert_image_to_video_dataset(
Args:
dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
@@ -1602,7 +1595,6 @@ def convert_image_to_video_dataset(
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
+29 -126
View File
@@ -68,7 +68,6 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
@@ -76,11 +75,11 @@ from lerobot.datasets.video_utils import (
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
class LeRobotDatasetMetadata:
@@ -314,7 +313,7 @@ class LeRobotDatasetMetadata:
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
@@ -546,19 +545,12 @@ class LeRobotDatasetMetadata:
def _encode_video_worker(
video_key: str,
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
encoder_threads: int | None = None,
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
)
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
shutil.rmtree(img_dir)
return temp_path
@@ -578,9 +570,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -664,11 +653,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -694,17 +683,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
encoding is CPU-heavy.
"""
super().__init__()
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
@@ -716,8 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
self.vcodec = resolve_vcodec(vcodec)
self._encoder_threads = encoder_threads
self.vcodec = vcodec
# Unused attributes
self.image_writer = None
@@ -725,7 +708,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self._streaming_encoder = None
self.root.mkdir(exist_ok=True, parents=True)
@@ -747,7 +729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (FileNotFoundError, NotADirectoryError):
except (AssertionError, FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -767,19 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Initialize streaming encoder for resumed recording
if streaming_encoding and len(self.meta.video_keys) > 0:
self._streaming_encoder = StreamingVideoEncoder(
fps=self.meta.fps,
vcodec=self.vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
writer = getattr(self, "writer", None)
@@ -839,7 +808,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -1135,8 +1104,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
self._close_writer()
self.meta._close_writer()
if self._streaming_encoder is not None:
self._streaming_encoder.close()
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
@@ -1191,13 +1158,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
# Start streaming encoder on first frame of episode (once, before iterating keys)
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
video_keys=list(self.meta.video_keys),
temp_dir=self.root,
)
# Add frame features to episode_buffer
for key in frame:
if key not in self.features:
@@ -1205,10 +1165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
self._streaming_encoder.feed_frame(key, frame[key])
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
elif self.features[key]["dtype"] in ["image", "video"]:
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@@ -1269,38 +1226,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
has_video_keys = len(self.meta.video_keys) > 0
use_streaming = self._streaming_encoder is not None and has_video_keys
use_batched_encoding = self.batch_encoding_size > 1
if use_streaming:
# Compute stats for non-video features only (video stats come from encoder)
non_video_buffer = {
k: v
for k, v in episode_buffer.items()
if self.features.get(k, {}).get("dtype") not in ("video",)
}
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
else:
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_metadata = self._save_episode_data(episode_buffer)
has_video_keys = len(self.meta.video_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1
if use_streaming:
# Finish streaming encoding and collect results
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self.meta.video_keys:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
elif has_video_keys and not use_batched_encoding:
if has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
@@ -1314,7 +1246,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root,
self.fps,
self.vcodec,
self._encoder_threads,
): video_key
for video_key in self.meta.video_keys
}
@@ -1583,10 +1514,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata
def clear_episode_buffer(self, delete_images: bool = True) -> None:
# Cancel streaming encoder if active
if self._streaming_encoder is not None:
self._streaming_encoder.cancel_episode()
# Clean up image files for the current episode buffer
if delete_images:
# Wait for the async image writer to finish
@@ -1634,9 +1561,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
return _encode_video_worker(
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
)
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
@classmethod
def create(
@@ -1653,13 +1578,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
vcodec = resolve_vcodec(vcodec)
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -1668,7 +1590,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
features=features,
root=root,
use_videos=use_videos,
metadata_buffer_size=metadata_buffer_size,
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
@@ -1678,7 +1599,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
obj.vcodec = vcodec
obj._encoder_threads = encoder_threads
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
@@ -1700,22 +1620,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._lazy_loading = False
obj._recorded_frames = 0
obj._writer_closed_for_reading = False
# Initialize streaming encoder
if streaming_encoding and len(obj.meta.video_keys) > 0:
obj._streaming_encoder = StreamingVideoEncoder(
fps=fps,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
encoder_threads=encoder_threads,
)
else:
obj._streaming_encoder = None
return obj
@@ -1771,12 +1675,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
+9 -10
View File
@@ -216,17 +216,16 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
+17 -6
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any
from typing import Any, Generic, TypeVar
import datasets
import numpy as np
@@ -78,6 +78,8 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
T = TypeVar("T")
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
@@ -120,9 +122,19 @@ def load_nested_dataset(
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -339,7 +351,6 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks
@@ -1232,7 +1243,7 @@ class LookAheadError(Exception):
pass
class Backtrackable[T]:
class Backtrackable(Generic[T]):
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
@@ -36,11 +36,8 @@ Convert a local dataset (works in place):
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory \
--root=/path/to/local/dataset/directory
--push-to-hub=false
N.B. Path semantics (v2): --root is the exact dataset folder containing
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
```
"""
@@ -108,7 +105,7 @@ episodes.jsonl
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
NEW
meta/episodes/chunk-000/file_000.parquet
meta/episodes/chunk-000/episodes_000.parquet
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
-------------------------
OLD
@@ -116,16 +113,15 @@ tasks.jsonl
{"task_index": 1, "task": "Put the blue block in the green bowl"}
NEW
meta/tasks.parquet
meta/tasks/chunk-000/file_000.parquet
task_index | task
-------------------------
OLD
episodes_stats.jsonl
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
NEW
meta/episodes/chunk-000/file_000.parquet
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
meta/episodes_stats/chunk-000/file_000.parquet
episode_index | mean | std | min | max
-------------------------
UPDATE
meta/info.json
@@ -174,7 +170,7 @@ def convert_tasks(root, new_root):
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
write_tasks(df_tasks, new_root)
@@ -205,6 +201,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
image_keys = get_image_keys(root)
ep_idx = 0
chunk_idx = 0
file_idx = 0
size_in_mb = 0
@@ -214,23 +211,9 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
logging.info(f"Converting data files from {len(ep_paths)} episodes")
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
# Check if we need to start a new file BEFORE creating metadata
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
# Write the accumulated data files
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Move to next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Reset for the next file
size_in_mb = 0
paths_to_cat = []
# Now create metadata with correct chunk/file indices
ep_metadata = {
"episode_index": ep_idx,
"data/chunk_index": chunk_idx,
@@ -241,7 +224,20 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
size_in_mb += ep_size_in_mb
num_frames += ep_num_frames
episodes_metadata.append(ep_metadata)
paths_to_cat.append(ep_path)
ep_idx += 1
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
if paths_to_cat:
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
# Reset for the next file
size_in_mb = ep_size_in_mb
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
# Write remaining data if any
if paths_to_cat:
@@ -473,7 +469,7 @@ def convert_dataset(
# Set root based on whether local dataset path is provided
use_local_dataset = False
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
if root.exists():
validate_local_dataset_version(root)
use_local_dataset = True
@@ -533,7 +529,7 @@ if __name__ == "__main__":
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
@@ -557,7 +553,7 @@ if __name__ == "__main__":
"--root",
type=str,
default=None,
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
help="Local directory to use for downloading/writing the dataset.",
)
parser.add_argument(
"--push-to-hub",
+46 -480
View File
@@ -13,106 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import glob
import importlib
import logging
import queue
import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar
import av
import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
"h264_videotoolbox", # macOS
"hevc_videotoolbox", # macOS
"h264_nvenc", # NVIDIA GPU
"hevc_nvenc", # NVIDIA GPU
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
return options
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
pass # nosec B110
return available
def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logging.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logging.info(f"Auto-selected video codec: {encoder}")
return encoder
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
@@ -227,17 +146,16 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -249,11 +167,7 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
assert len(timestamps) == len(closest_frames)
return closest_frames
@@ -358,16 +272,15 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -396,13 +309,14 @@ def encode_video_frames(
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
log_level: int | None = av.logging.WARNING,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
vcodec = resolve_vcodec(vcodec)
# Check encoder availability
if vcodec not in ["h264", "hevc", "libsvtav1"]:
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -433,22 +347,21 @@ def encode_video_frames(
width, height = dummy_image.size
# Define video codec options
video_options = _get_codec_options(vcodec, g, crf, preset)
video_options = {}
if g is not None:
video_options["g"] = str(g)
if crf is not None:
video_options["crf"] = str(crf)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
if vcodec == "libsvtav1":
video_options["preset"] = str(preset) if preset is not None else "12"
# Set logging level
if log_level is not None:
@@ -567,348 +480,6 @@ def concatenate_video_files(
Path(tmp_concatenate_path).unlink()
class _CameraEncoderThread(threading.Thread):
"""A thread that encodes video frames streamed via a queue into an MP4 file.
One instance is created per camera per episode. Frames are received as numpy arrays
from the main thread, encoded in real-time using PyAV (which releases the GIL during
encoding), and written to disk. Stats are computed incrementally using
RunningQuantileStats and returned via result_queue.
"""
def __init__(
self,
video_path: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int | None,
crf: int | None,
preset: int | None,
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
def run(self) -> None:
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
container = None
output_stream = None
stats_tracker = RunningQuantileStats()
frame_count = 0
try:
logging.getLogger("libav").setLevel(av.logging.WARNING)
while True:
try:
frame_data = self.frame_queue.get(timeout=1)
except queue.Empty:
if self.stop_event.is_set():
break
continue
if frame_data is None:
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
if container is None:
height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
if packet:
container.mux(packet)
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
frame_count += 1
# Flush encoder
if output_stream is not None:
packet = output_stream.encode()
if packet:
container.mux(packet)
if container is not None:
container.close()
av.logging.restore_default_callback()
# Get stats and put on result queue
if frame_count >= 2:
stats = stats_tracker.get_statistics()
self.result_queue.put(("ok", stats))
else:
self.result_queue.put(("ok", None))
except Exception as e:
logging.error(f"Encoder thread error: {e}")
if container is not None:
with contextlib.suppress(Exception):
container.close()
self.result_queue.put(("error", str(e)))
class StreamingVideoEncoder:
"""Manages per-camera encoder threads for real-time video encoding during recording.
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
this class streams frames directly to encoder threads, eliminating the
PNG round-trip and making save_episode() near-instant.
Uses threading instead of multiprocessing to avoid the overhead of pickling large
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
so encoding runs in parallel with the main recording loop.
"""
def __init__(
self,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
self.fps = fps
self.vcodec = resolve_vcodec(vcodec)
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
self._threads: dict[str, _CameraEncoderThread] = {}
self._stop_events: dict[str, threading.Event] = {}
self._video_paths: dict[str, Path] = {}
self._dropped_frames: dict[str, int] = {}
self._episode_active = False
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
"""Start encoder threads for a new episode.
Args:
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
temp_dir: Base directory for temporary MP4 files
"""
if self._episode_active:
self.cancel_episode()
self._dropped_frames.clear()
for video_key in video_keys:
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
result_queue: queue.Queue = queue.Queue(maxsize=1)
stop_event = threading.Event()
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=self.vcodec,
pix_fmt=self.pix_fmt,
g=self.g,
crf=self.crf,
preset=self.preset,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self.encoder_threads,
)
encoder_thread.start()
self._frame_queues[video_key] = frame_queue
self._result_queues[video_key] = result_queue
self._threads[video_key] = encoder_thread
self._stop_events[video_key] = stop_event
self._video_paths[video_key] = video_path
self._episode_active = True
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
"""Feed a frame to the encoder for a specific camera.
A copy of the image is made before enqueueing to prevent race conditions
with camera drivers that may reuse buffers. If the encoder queue is full
(encoder can't keep up), the frame is dropped with a warning instead of
crashing the recording session.
Args:
video_key: The video feature key
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
Raises:
RuntimeError: If the encoder thread has crashed
"""
if not self._episode_active:
raise RuntimeError("No active episode. Call start_episode() first.")
thread = self._threads[video_key]
if not thread.is_alive():
# Check for error
try:
status, msg = self._result_queues[video_key].get_nowait()
if status == "error":
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
except queue.Empty:
pass
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
try:
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
except queue.Full:
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
count = self._dropped_frames[video_key]
# Log periodically to avoid spam (1st, then every 10th)
if count == 1 or count % 10 == 0:
logging.warning(
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
)
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
"""Finish encoding the current episode.
Sends sentinel values, waits for encoder threads to complete,
and collects results.
Returns:
Dict mapping video_key to (mp4_path, stats_dict_or_None)
"""
if not self._episode_active:
raise RuntimeError("No active episode to finish.")
results = {}
# Report dropped frames
for video_key, count in self._dropped_frames.items():
if count > 0:
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
# Send sentinel to all queues
for video_key in self._frame_queues:
self._frame_queues[video_key].put(None)
# Wait for all threads and collect results
for video_key in self._threads:
self._threads[video_key].join(timeout=120)
if self._threads[video_key].is_alive():
logging.error(f"Encoder thread for {video_key} did not finish in time")
self._stop_events[video_key].set()
self._threads[video_key].join(timeout=5)
results[video_key] = (self._video_paths[video_key], None)
continue
try:
status, data = self._result_queues[video_key].get(timeout=5)
if status == "error":
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
results[video_key] = (self._video_paths[video_key], data)
except queue.Empty:
logging.error(f"No result from encoder thread for {video_key}")
results[video_key] = (self._video_paths[video_key], None)
self._cleanup()
self._episode_active = False
return results
def cancel_episode(self) -> None:
"""Cancel the current episode, stopping encoder threads and cleaning up."""
if not self._episode_active:
return
# Signal all threads to stop
for video_key in self._stop_events:
self._stop_events[video_key].set()
# Wait for threads to finish
for video_key in self._threads:
self._threads[video_key].join(timeout=5)
# Clean up temp MP4 files
video_path = self._video_paths.get(video_key)
if video_path is not None and video_path.exists():
shutil.rmtree(str(video_path.parent), ignore_errors=True)
self._cleanup()
self._episode_active = False
def close(self) -> None:
"""Close the encoder, canceling any in-progress episode."""
if self._episode_active:
self.cancel_episode()
def _cleanup(self) -> None:
"""Clean up queues and thread tracking dicts."""
for q in self._frame_queues.values():
with contextlib.suppress(Exception):
while not q.empty():
q.get_nowait()
self._frame_queues.clear()
self._result_queues.clear()
self._threads.clear()
self._stop_events.clear()
self._video_paths.clear()
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
@@ -943,7 +514,7 @@ with warnings.catch_warnings():
def get_audio_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.WARNING)
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting audio stream information
audio_info = {}
@@ -975,7 +546,7 @@ def get_audio_info(video_path: Path | str) -> dict:
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
logging.getLogger("libav").setLevel(av.logging.WARNING)
logging.getLogger("libav").setLevel(av.logging.ERROR)
# Getting video stream information
video_info = {}
@@ -1061,15 +632,8 @@ class VideoEncodingManager:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
if streaming_encoder is not None:
# Handle streaming encoder cleanup
if exc_type is not None:
streaming_encoder.cancel_episode()
streaming_encoder.close()
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
# Handle any remaining episodes that haven't been batch encoded
if self.dataset.episodes_since_last_encoding > 0:
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
else:
@@ -1086,8 +650,8 @@ class VideoEncodingManager:
# Finalize the dataset to properly close all writers
self.dataset.finalize()
# Clean up episode images if recording was interrupted (only for non-streaming mode)
if exc_type is not None and streaming_encoder is None:
# Clean up episode images if recording was interrupted
if exc_type is not None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
@@ -1101,12 +665,14 @@ class VideoEncodingManager:
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"
if img_dir.exists():
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
# Check for any remaining PNG files
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
# Only remove the images directory if no PNG files remain
if img_dir.exists():
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
-1
View File
@@ -205,7 +205,6 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
+2 -7
View File
@@ -112,7 +112,6 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -146,9 +145,7 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -298,8 +295,7 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -377,7 +373,6 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)
+4 -6
View File
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(self.groups)
self.group_names = list(groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,20 +230,18 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")
+15 -41
View File
@@ -23,7 +23,6 @@ from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
@@ -37,6 +36,7 @@ else:
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -155,7 +155,6 @@ class DamiaoMotorsBus(MotorsBusBase):
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
@@ -163,6 +162,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
@@ -208,9 +211,6 @@ class DamiaoMotorsBus(MotorsBusBase):
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
@@ -246,7 +246,6 @@ class DamiaoMotorsBus(MotorsBusBase):
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
@@ -254,6 +253,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
@@ -282,10 +283,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -344,10 +341,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -363,10 +356,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
@@ -405,13 +394,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
Dictionary mapping recv_id to CAN message
"""
responses: dict[int, can.Message] = {}
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
@@ -475,9 +461,6 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
@@ -505,9 +488,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
@@ -582,9 +562,10 @@ class DamiaoMotorsBus(MotorsBusBase):
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
@@ -614,7 +595,6 @@ class DamiaoMotorsBus(MotorsBusBase):
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
@@ -625,6 +605,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
@@ -674,10 +656,6 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
@@ -700,12 +678,10 @@ class DamiaoMotorsBus(MotorsBusBase):
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
@@ -714,8 +690,6 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
@@ -758,9 +732,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion(
self,
motors: str | list[str] | None = None,
motors: NameOrID | list[NameOrID] | None = None,
display_values: bool = True,
) -> tuple[dict[str, Value], dict[str, Value]]:
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
+8 -8
View File
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
return {id_: data[0] for id_, data in data_list.items()}
+9 -9
View File
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list: dict[int, int] = {}
data_list = {}
status_length = 6
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
+94 -97
View File
@@ -23,13 +23,12 @@ from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
from typing import Protocol
from typing import Protocol, TypeAlias
import serial
from deepdiff import DeepDiff
@@ -38,8 +37,8 @@ from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.utils import enter_pressed, move_cursor_up
type NameOrID = str | int
type Value = int | float
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
logger = logging.getLogger(__name__)
@@ -94,7 +93,7 @@ class MotorsBusBase(abc.ABC):
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@@ -180,16 +179,15 @@ class Motor:
class PortHandler(Protocol):
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
def openPort(self): ...
def closePort(self): ...
@@ -242,22 +240,19 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -270,17 +265,15 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -407,7 +400,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> str:
def _get_motor_model(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -415,19 +408,17 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -649,19 +640,18 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -738,19 +728,24 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motor_names:
for motor in motors:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -759,9 +754,7 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -771,12 +764,17 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[str, Value]: Mapping *motor name written homing offset*.
dict[NameOrID, Value]: Mapping *motor written homing offset*.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -788,8 +786,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -801,25 +799,30 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motor_names, normalize=False)
positions = self.sync_read("Present_Position", motors, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motor_names:
for motor in motors:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -827,9 +830,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motor_names) + 3)
move_cursor_up(len(motors) + 3)
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -952,12 +955,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return None
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return None
return
return model_number
@@ -1004,13 +1007,12 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
decoded = self._decode_sign(data_name, {id_: value})
id_value = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return normalized[id_]
id_value = self._normalize(id_value)
return decoded[id_]
return id_value[id_]
def _read(
self,
@@ -1021,7 +1023,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int, int]:
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1071,14 +1073,13 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
int_value = self._unnormalize({id_: value})[id_]
value = self._unnormalize({id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1112,7 +1113,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: NameOrID | Sequence[NameOrID] | None = None,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1121,7 +1122,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1142,17 +1143,16 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
raw_ids_values, _ = self._sync_read(
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
decoded = self._decode_sign(data_name, raw_ids_values)
ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
ids_values = self._normalize(ids_values)
return {self._id_to_name(id_): value for id_, value in decoded.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self,
@@ -1224,24 +1224,21 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._unnormalize(ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(
self,
@@ -1277,4 +1274,4 @@ class SerialMotorsBus(MotorsBusBase):
# Backward compatibility alias
MotorsBus = SerialMotorsBus
MotorsBus: TypeAlias = SerialMotorsBus
-18
View File
@@ -1,18 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .robstride import RobstrideMotorsBus
from .tables import *
File diff suppressed because it is too large Load Diff
-120
View File
@@ -1,120 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration tables for Damiao motors."""
from enum import IntEnum
# Motor type definitions
class MotorType(IntEnum):
O0 = 0
O1 = 1
O2 = 2
O3 = 3
O4 = 4
O5 = 5
ELO5 = 6
O6 = 7
class CommMode(IntEnum):
PrivateProtocole = 0
CANopen = 1
MIT = 2
# Control modes
class ControlMode(IntEnum):
MIT = 0
POS_VEL = 1
VEL = 2
# Motor limit parameters [PMAX, VMAX, TMAX]
# PMAX: Maximum position (rad)
# VMAX: Maximum velocity (rad/s)
# TMAX: Maximum torque (N·m)
MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = {
MotorType.O0: (12.57, 33, 14),
MotorType.O1: (12.57, 44, 17),
MotorType.O2: (12.57, 33, 20),
MotorType.O3: (12.57, 33, 60),
MotorType.O4: (12.57, 33, 120),
MotorType.O5: (12.57, 50, 5.5),
MotorType.ELO5: (12.57, 50, 6),
MotorType.O6: (112.5, 50, 36),
}
# Motor model names
MODEL_NAMES = {
MotorType.O0: "O0",
MotorType.O1: "O1",
MotorType.O2: "O2",
MotorType.O3: "O3",
MotorType.O4: "O4",
MotorType.O5: "O5",
MotorType.ELO5: "ELO5",
MotorType.O6: "O6",
}
# Motor resolution table (encoder counts per revolution)
MODEL_RESOLUTION = {
"O0": 65536,
"O1": 65536,
"O2": 65536,
"O3": 65536,
"O4": 65536,
"O5": 65536,
"ELO5": 65536,
"O6": 65536,
}
# CAN baudrates supported by Robstride motors
AVAILABLE_BAUDRATES = [
1000000, # 4: 1 mbps (default)
]
DEFAULT_BAUDRATE = 1000000
# Default timeout in milliseconds
DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s
# Data that should be normalized
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
# MIT control parameter ranges
MIT_KP_RANGE = (0.0, 500.0)
MIT_KD_RANGE = (0.0, 5.0)
# CAN frame command IDs
CAN_CMD_ENABLE = 0xFC
CAN_CMD_DISABLE = 0xFD
CAN_CMD_SET_ZERO = 0xFE
CAN_CMD_CLEAR_FAULT = 0xFB
CAN_CMD_QUERY_PARAM = 0x33
CAN_CMD_WRITE_PARAM = 0x55
CAN_CMD_SAVE_PARAM = 0xAA
# CAN ID for parameter operations
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.001
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02
-2
View File
@@ -15,7 +15,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
@@ -29,7 +28,6 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"PI0FastConfig",
+16 -7
View File
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- Either:
@@ -48,12 +48,21 @@ class ACTConfig(PreTrainedConfig):
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_features` and `output_features`.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
@@ -48,29 +48,32 @@ class DiffusionConfig(PreTrainedConfig):
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
backbone. If None, no resizing is done and the original image resolution is used.
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
@@ -120,9 +123,7 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
@@ -147,10 +148,6 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -183,25 +180,6 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
@@ -229,12 +207,13 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.resize_shape is None and self.crop_shape is not None:
if self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for `{key}`."
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
)
# Check that all input images have the same shape.
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
@@ -454,18 +446,12 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
if config.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -491,16 +477,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -516,10 +499,7 @@ class DiffusionRgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
# Preprocess: maybe crop (if it was set up in the __init__).
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
+6 -21
View File
@@ -18,9 +18,10 @@ from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict, Unpack
from typing import Any, TypedDict
import torch
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -31,7 +32,6 @@ from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -67,7 +67,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -86,10 +87,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
@@ -150,8 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -166,8 +163,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -314,16 +309,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
@@ -4,16 +4,17 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
ImagesKwargs,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
@@ -76,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
@@ -164,11 +165,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _resize_for_patching(
self,
image: torch.Tensor,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: F.InterpolationMode,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> torch.Tensor:
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
@@ -218,8 +219,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
return best_ratio
def _pad_for_patching(
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
@@ -235,15 +236,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _get_image_patches(
self,
image: torch.Tensor,
image: "torch.Tensor",
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: F.InterpolationMode,
interpolation: "F.InterpolationMode",
pad_during_tiling: bool,
) -> list[torch.Tensor]:
) -> list["torch.Tensor"]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
@@ -304,8 +305,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _pad_for_batching(
self,
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
@@ -326,14 +327,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _preprocess(
self,
images: list[torch.Tensor],
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: F.InterpolationMode | None,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
@@ -1,37 +0,0 @@
# Multitask DiT Policy
## Citation
If you use this work, please cite the following works:
```bibtex
@misc{jones2025multitaskditpolicy,
author = {Bryson Jones},
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
year = {2025},
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
note = {Blog post}
}
```
```bibtex
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
author = {TRI LBM Team},
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
year = {2025},
eprint = {arXiv:2507.05331},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2507.05331}
}
```
```bibtex
@misc{bostondynamics2025largebehaviormodelsatlas,
author = {Boston Dynamics and TRI Research Team},
title = {Large Behavior Models and Atlas Find New Footing},
year = {2025},
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
note = {Blog post}
}
```
@@ -1,21 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]
@@ -1,256 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
"""Configuration for the Multi-Task Diffusion Transformer (DiT) policy.
A transformer-based policy that supports both diffusion and flow matching objectives
for multi-task robot learning with text and vision conditioning.
"""
n_obs_steps: int = 2 # Number of observation steps for temporal context
horizon: int = 32 # Number of action steps to predict
n_action_steps: int = 24 # Actions executed per policy call (~0.8s at 30Hz)
# Objective Selection
objective: str = "diffusion" # "diffusion" or "flow_matching"
# --- Diffusion-specific (used when objective="diffusion") ---
noise_scheduler_type: str = "DDPM" # "DDPM" or "DDIM"
num_train_timesteps: int = 100 # Number of diffusion timesteps
beta_schedule: str = "squaredcos_cap_v2" # Noise schedule type
beta_start: float = 0.0001 # Starting noise level
beta_end: float = 0.02 # Ending noise level
prediction_type: str = "epsilon" # "epsilon" (predict noise) or "sample" (predict clean)
clip_sample: bool = True # Clip samples during denoising
clip_sample_range: float = 1.0 # Clipping range [-x, x]
num_inference_steps: int | None = None # Denoising steps at inference (defaults to num_train_timesteps)
# --- Flow Matching-specific (used when objective="flow_matching") ---
sigma_min: float = 0.0 # Minimum noise in flow interpolation path
num_integration_steps: int = 100 # ODE integration steps at inference
integration_method: str = "euler" # ODE solver: "euler" or "rk4"
timestep_sampling_strategy: str = "beta" # "uniform" or "beta"
timestep_sampling_s: float = 0.999 # (beta only) Max timestep threshold
timestep_sampling_alpha: float = 1.5 # (beta only) Beta distribution alpha
timestep_sampling_beta: float = 1.0 # (beta only) Beta distribution beta
# Transformer Architecture
hidden_dim: int = 512 # Transformer hidden dimension
num_layers: int = 6 # Number of transformer layers
num_heads: int = 8 # Number of attention heads
dropout: float = 0.1 # Dropout rate
use_positional_encoding: bool = False # Use absolute positional encoding
timestep_embed_dim: int = 256 # Timestep embedding dimension
use_rope: bool = True # Use Rotary Position Embedding
rope_base: float = 10000.0 # RoPE base frequency
# Vision Encoder (CLIP)
vision_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
use_separate_rgb_encoder_per_camera: bool = False # Separate encoder per camera view
vision_encoder_lr_multiplier: float = 0.1 # LR multiplier for vision encoder
image_resize_shape: tuple[int, int] | None = None # Resize images before crop
image_crop_shape: tuple[int, int] | None = (224, 224) # Crop shape (CLIP default)
image_crop_is_random: bool = True # Random crop during training, center at inference
# Text Encoder (CLIP)
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Training/Optimizer
optimizer_lr: float = 2e-5
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 0
do_mask_loss_for_padding: bool = False
# Auto-calculated
drop_n_last_frames: int | None = None
def __post_init__(self):
super().__post_init__()
if self.drop_n_last_frames is None:
self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1
self._validate()
def _validate(self):
"""Validate configuration parameters."""
# Objective validation
if self.objective not in ["diffusion", "flow_matching"]:
raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")
# Transformer validation
if self.hidden_dim <= 0:
raise ValueError("hidden_dim must be positive")
if self.num_layers <= 0:
raise ValueError("num_layers must be positive")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")
if self.hidden_dim % self.num_heads != 0:
raise ValueError("hidden_dim must be divisible by num_heads")
if not (0.0 <= self.dropout <= 1.0):
raise ValueError("dropout must be between 0.0 and 1.0")
# Vision encoder validation
if "clip" not in self.vision_encoder_name.lower():
raise ValueError(
f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
)
if (
self.image_resize_shape
and self.image_crop_shape
and (
self.image_crop_shape[0] > self.image_resize_shape[0]
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
raise ValueError(
f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
)
# Objective-specific validation
if self.objective == "diffusion":
if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
raise ValueError(
f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
)
if self.prediction_type not in ["epsilon", "sample"]:
raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
if self.num_train_timesteps <= 0:
raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")
elif self.objective == "flow_matching":
if not (0.0 <= self.sigma_min <= 1.0):
raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
if self.num_integration_steps <= 0:
raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
if self.integration_method not in ["euler", "rk4"]:
raise ValueError(
f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
)
if self.timestep_sampling_strategy not in ["uniform", "beta"]:
raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
if self.timestep_sampling_strategy == "beta":
if not (0.0 < self.timestep_sampling_s <= 1.0):
raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
if self.timestep_sampling_alpha <= 0:
raise ValueError("timestep_sampling_alpha must be positive")
if self.timestep_sampling_beta <= 0:
raise ValueError("timestep_sampling_beta must be positive")
def get_optimizer_preset(self) -> AdamConfig:
return AdamConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
return DiffuserSchedulerConfig(
name=self.scheduler_name,
num_warmup_steps=self.scheduler_warmup_steps,
)
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_ft.shape:
raise ValueError(
f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
)
@property
def is_diffusion(self) -> bool:
return self.objective == "diffusion"
@property
def is_flow_matching(self) -> bool:
return self.objective == "flow_matching"
@property
def observation_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1))
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
@property
def reward_delta_indices(self) -> None:
return None
@@ -1,803 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Task Diffusion Transformer (DiT) Policy
Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.
References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""
import math
from collections import deque
from typing import TYPE_CHECKING
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
config_class = MultiTaskDiTConfig
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
super().__init__(config)
config.validate_features()
self.config = config
self._queues = None
self.observation_encoder = ObservationEncoder(config)
conditioning_dim = self.observation_encoder.conditioning_dim
self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)
action_dim = config.action_feature.shape[0]
horizon = config.horizon
if config.is_diffusion:
self.objective = DiffusionObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
elif config.is_flow_matching:
self.objective = FlowMatchingObjective(
config,
action_dim=action_dim,
horizon=horizon,
do_mask_loss_for_padding=config.do_mask_loss_for_padding,
)
else:
raise ValueError(f"Unsupported objective: {config.objective}")
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if "observation_encoder.vision_encoder" in name:
vision_encoder_params.append(param)
else:
non_vision_params.append(param)
return [
{"params": non_vision_params},
{
"params": vision_encoder_params,
"lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
},
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)
start = n_obs_steps - 1
end = start + self.config.n_action_steps
actions = actions[:, start:end]
return actions
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
actions = self._generate_actions(batch)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Prepare batch by stacking image features if needed."""
if self.config.image_features:
batch = dict(batch) # shallow copy to avoid modifying original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
return batch
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch = dict(batch) # shallow copy to avoid modifying original
batch.pop(ACTION)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
batch = self._prepare_batch(batch)
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer.
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
pipeline to see how the tokenization is handled.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Encode pre-tokenized text to feature vectors."""
# Ensure inputs are on the same device as the model
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
with torch.no_grad():
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_rgb_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_rgb_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
text_features = self.text_encoder(input_ids, attention_mask)
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x
@@ -1,105 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features.
2. Adding a batch dimension.
3. Tokenizing the language task description (if present).
4. Moving the data to the specified device.
5. Normalizing the input and output features based on dataset statistics.
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output features to their original scale.
2. Moving the data to the CPU.
Args:
config: The configuration object for the Multi-Task DiT policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.text_encoder_name,
padding=config.tokenizer_padding,
padding_side=config.tokenizer_padding_side,
max_length=config.tokenizer_max_length,
truncation=config.tokenizer_truncation,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+54 -68
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,21 +32,13 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
@@ -199,7 +191,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -210,7 +202,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -229,14 +221,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -262,10 +254,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -273,7 +265,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -285,15 +277,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -366,7 +358,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -374,7 +366,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -385,13 +377,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
torch_dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -406,11 +398,10 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -422,8 +413,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -433,23 +424,15 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -463,7 +446,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -487,7 +470,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
models = [self.paligemma.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -527,7 +510,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -593,19 +576,29 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
@@ -767,7 +760,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -841,7 +834,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -915,7 +908,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -1005,12 +997,14 @@ class PI0Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -1018,7 +1012,7 @@ class PI0Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1031,7 +1025,7 @@ class PI0Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1076,7 +1070,7 @@ class PI0Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -1126,14 +1120,6 @@ class PI0Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
+60 -71
View File
@@ -15,16 +15,16 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
from typing import TYPE_CHECKING, Literal, TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -32,20 +32,14 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -98,11 +92,10 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,)).to(device)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
@@ -196,7 +189,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -207,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -226,14 +219,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -259,10 +252,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -270,7 +263,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -282,15 +275,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -363,7 +356,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -371,7 +364,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -382,13 +375,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
torch_dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -403,11 +396,10 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -419,8 +411,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -430,23 +422,15 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -460,7 +444,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -484,7 +468,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
models = [self.paligemma.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -524,7 +508,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -589,19 +573,29 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
@@ -743,7 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -814,7 +808,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -886,7 +880,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -976,12 +969,14 @@ class PI05Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -989,7 +984,7 @@ class PI05Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1002,7 +997,7 @@ class PI05Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1014,6 +1009,8 @@ class PI05Policy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -1047,7 +1044,7 @@ class PI05Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -1101,14 +1098,6 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
@@ -23,6 +23,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -67,6 +68,9 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
action_tokenizer_name: str = "physical-intelligence/fast"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128
@@ -19,12 +19,13 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
from typing import TYPE_CHECKING, Literal, TypedDict
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _scipy_available, _transformers_available
@@ -37,16 +38,11 @@ else:
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -125,7 +121,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -136,7 +132,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -210,22 +206,16 @@ class PI0FastPaliGemma(nn.Module):
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
self.paligemma.model.language_model = PiGemmaModel(text_config)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.to_bfloat16_for_selected_params(precision)
@@ -238,11 +228,10 @@ class PI0FastPaliGemma(nn.Module):
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -253,18 +242,10 @@ class PI0FastPaliGemma(nn.Module):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -278,7 +259,7 @@ class PI0FastPaliGemma(nn.Module):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -325,14 +306,24 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
@@ -341,8 +332,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
# Call the proper gradient_checkpointing_disable() method
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
def _apply_checkpoint(self, func, *args, **kwargs):
@@ -532,7 +523,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -625,7 +616,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -723,7 +714,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Ensure correct precision (bfloat16/float32)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -906,12 +897,14 @@ class PI0FastPolicy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -919,7 +912,7 @@ class PI0FastPolicy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -932,9 +925,8 @@ class PI0FastPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -944,6 +936,8 @@ class PI0FastPolicy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -977,7 +971,7 @@ class PI0FastPolicy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -23,6 +23,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
@@ -68,6 +69,9 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
-363
View File
@@ -1,363 +0,0 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch import nn
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
)
else:
GemmaAttention = None
GemmaConfig = None
GemmaForCausalLM = None
GemmaMLP = None
GemmaModel = None
PaliGemmaModel = None
PaliGemmaForConditionalGeneration = None
DynamicCache = None
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
def _gated_residual(
x: torch.Tensor | None,
y: torch.Tensor | None,
gate: torch.Tensor | None,
) -> torch.Tensor | None:
"""Gated residual: x + y when gate is None, else x + y * gate."""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def layernorm_forward(
layernorm: nn.Module,
x: torch.Tensor,
cond: torch.Tensor | None = None,
):
"""
call layernorm and return hidden states and gate
if cond is not None, use conditional norm
otherwise, use normal gemma norm
"""
if cond is not None:
return layernorm(x, cond=cond)
else:
return layernorm(x)
class PiGemmaRMSNorm(nn.Module):
"""
Adaptive RMSNorm for PI Gemma (AdaRMS).
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
"""
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
if cond_dim is not None:
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(
self,
x: torch.Tensor,
cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
dtype = x.dtype
normed = self._norm(x)
if cond is None or self.dense is None:
normed = normed * (1.0 + self.weight.float())
return normed.type_as(x), None
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()
return normed.to(dtype), gate.to(dtype)
def extra_repr(self) -> str:
if self.dense is not None:
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
return f"dim={self.dim}, eps={self.eps}"
def _get_pi_gemma_decoder_layer_base():
"""base for PiGemmaDecoderLayer"""
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = (
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
)
self.input_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
self.post_attention_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
hidden_states, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
return hidden_states
return _PiGemmaDecoderLayerBase
class PiGemmaModel(GemmaModel): # type: ignore[misc]
"""
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# if not getattr(config, "use_adarms", False):
# return
cond_dim = getattr(config, "adarms_cond_dim", None)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
import logging
logging.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
"""
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PiGemmaModel(config)
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
def __init__(self, config):
super().__init__(config)
self.language_model = PiGemmaModel(config.text_config)
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
def __init__(self, config):
super().__init__(config)
self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
__all__ = [
"PiGemmaModel",
"PiGemmaForCausalLM",
"PiGemmaRMSNorm",
"_gated_residual",
"layernorm_forward",
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
+2 -1
View File
@@ -19,7 +19,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypedDict, TypeVar, Unpack
from typing import TypedDict, TypeVar
import packaging
import safetensors
@@ -28,6 +28,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
@@ -27,18 +27,18 @@ Usage:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Faster computation with stride (compute every 5 frames, interpolate the rest)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--stride 5
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 5
@@ -714,12 +714,12 @@ Examples:
# Full RA-BC computation with visualizations
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4
--reward-model-path pepijn223/sarm_single_uni4
# Visualize predictions only (no RA-BC computation)
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
--reward-model-path <USER>/sarm_single_uni4 \\
--reward-model-path pepijn223/sarm_single_uni4 \\
--visualize-only \\
--num-visualizations 10
""",
+3 -1
View File
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:

Some files were not shown because too many files have changed in this diff Show More