diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index 2fb23051c..7423495de 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -25,7 +25,7 @@ body:
id: system-info
attributes:
label: System Info
- description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
+ description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below.
render: Shell
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
validations:
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index df2e2db29..d37b1a92f 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,33 +1,40 @@
## What this does
+
Explain what this PR does. Feel free to tag your PR with the appropriate label(s).
Examples:
-| Title | Label |
+| Title | Label |
|----------------------|-----------------|
-| Fixes #[issue] | (🐛 Bug) |
-| Adds new dataset | (🗃️ Dataset) |
-| Optimizes something | (⚡️ Performance) |
+| Fixes #[issue] | (🐛 Bug) |
+| Adds new dataset | (🗃️ Dataset) |
+| Optimizes something | (⚡️ Performance) |
## How it was tested
+
Explain/show how you tested your changes.
Examples:
+
- Added `test_something` in `tests/test_stuff.py`.
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
- Optimized `some_function`, it now runs X times faster than previously.
## How to checkout & try? (for the reviewer)
+
Provide a simple way for the reviewer to try out your changes.
Examples:
+
```bash
pytest -sx tests/test_stuff.py::test_something
```
+
```bash
-python -m lerobot.scripts.train --some.option=true
+lerobot-train --some.option=true
```
## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR
+
**Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml
deleted file mode 100644
index 20974b85a..000000000
--- a/.github/workflows/build-docker-images.yml
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Inspired by
-# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
-name: Builds
-
-on:
- workflow_dispatch:
- workflow_call:
- schedule:
- - cron: "0 1 * * *"
-
-permissions: {}
-
-env:
- PYTHON_VERSION: "3.10"
-
-jobs:
- latest-cpu:
- name: CPU
- runs-on:
- group: aws-general-8-plus
- steps:
- - name: Install Git LFS
- run: |
- sudo apt-get update
- sudo apt-get install git-lfs
- git lfs install
-
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- with:
- cache-binary: false
-
- - name: Check out code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- lfs: true
- persist-credentials: false
-
- - name: Login to DockerHub
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
- with:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_PASSWORD }}
-
- - name: Build and Push CPU
- uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
- with:
- context: .
- file: ./docker/lerobot-cpu/Dockerfile
- push: true
- tags: huggingface/lerobot-cpu
- build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
-
-
- latest-cuda:
- name: GPU
- runs-on:
- group: aws-general-8-plus
- steps:
- - name: Install Git LFS
- run: |
- sudo apt-get update
- sudo apt-get install git-lfs
- git lfs install
-
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- with:
- cache-binary: false
-
- - name: Check out code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- lfs: true
- persist-credentials: false
-
- - name: Login to DockerHub
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
- with:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_PASSWORD }}
-
- - name: Build and Push GPU
- uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
- with:
- context: .
- file: ./docker/lerobot-gpu/Dockerfile
- push: true
- tags: huggingface/lerobot-gpu
- build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
-
-
- latest-cuda-dev:
- name: GPU Dev
- runs-on:
- group: aws-general-8-plus
- steps:
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- with:
- cache-binary: false
-
- - name: Check out code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- persist-credentials: false
-
- - name: Login to DockerHub
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
- with:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_PASSWORD }}
-
- - name: Build and Push GPU dev
- uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
- with:
- context: .
- file: ./docker/lerobot-gpu-dev/Dockerfile
- push: true
- tags: huggingface/lerobot-gpu:dev
- build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml
deleted file mode 100644
index 884e2e4b5..000000000
--- a/.github/workflows/build_documentation.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-name: Build documentation
-
-on:
- workflow_dispatch:
- push:
- paths:
- - "docs/**"
- branches:
- - main
- - doc-builder*
- - v*-release
-
-
-jobs:
- build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
- uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
- with:
- commit_sha: ${{ github.sha }}
- package: lerobot
- additional_args: --not_python_module
- secrets:
- token: ${{ secrets.HUGGINGFACE_PUSH }}
- hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
deleted file mode 100644
index 51bab10d5..000000000
--- a/.github/workflows/build_pr_documentation.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Build PR Documentation
-
-on:
- pull_request:
- paths:
- - "docs/**"
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-jobs:
- build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
- uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
- with:
- commit_sha: ${{ github.event.pull_request.head.sha }}
- pr_number: ${{ github.event.number }}
- package: lerobot
- additional_args: --not_python_module
diff --git a/.github/workflows/documentation-upload-pr.yml b/.github/workflows/documentation-upload-pr.yml
new file mode 100644
index 000000000..22ba11cbb
--- /dev/null
+++ b/.github/workflows/documentation-upload-pr.yml
@@ -0,0 +1,40 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow uploads the documentation preview built for a PR and comments the link on the PR.
+name: Documentation PR Upload
+permissions:
+ contents: read
+ pull-requests: write
+
+on:
+ # Triggered by the completion of the main 'Documentation' workflow.
+ workflow_run: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers
+ workflows: ["Documentation"]
+ types:
+ - completed
+
+jobs:
+ # This job uploads a preview of the documentation for a pull request.
+ upload_and_comment:
+ name: Upload Preview and Comment
+ if: >
+ github.event.workflow_run.event == 'pull_request' &&
+ github.event.workflow_run.conclusion == 'success'
+ uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
+ with:
+ package_name: lerobot
+ secrets:
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+ comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml
new file mode 100644
index 000000000..96005af3f
--- /dev/null
+++ b/.github/workflows/documentation.yml
@@ -0,0 +1,70 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles building documentation for both main branches and PRs.
+name: Documentation
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ # Triggers the workflow on push events to main for the docs folder
+ push:
+ branches:
+ - main
+ paths:
+ - "docs/**"
+
+ # Triggers the workflow on pull request events targeting main for the docs folder
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "docs/**"
+
+# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ # This job builds and deploys the official documentation.
+ build_main_docs:
+ name: Build Main Docs
+ if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
+ permissions:
+ contents: read
+ uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
+ with:
+ commit_sha: ${{ github.sha }}
+ package: lerobot
+ additional_args: --not_python_module
+ secrets:
+ token: ${{ secrets.HUGGINGFACE_PUSH }}
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
+
+ # This job builds a preview of the documentation for a pull request.
+ # The result of this job triggers the 'Upload PR Documentation' workflow.
+ build_pr_docs:
+ name: Build PR Docs
+ if: github.event_name == 'pull_request'
+ permissions:
+ contents: read
+ pull-requests: write
+ uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
+ with:
+ commit_sha: ${{ github.event.pull_request.head.sha }}
+ pr_number: ${{ github.event.number }}
+ package: lerobot
+ additional_args: --not_python_module
diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml
new file mode 100644
index 000000000..ad4938970
--- /dev/null
+++ b/.github/workflows/fast_tests.yml
@@ -0,0 +1,87 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles fast testing.
+name: Fast Tests
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/workflows/**"
+ - "pyproject.toml"
+ - "Makefile"
+ push:
+ branches:
+ - main
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/workflows/**"
+ - "pyproject.toml"
+ - "Makefile"
+
+permissions:
+ contents: read
+
+# Sets up the environment variables
+env:
+ UV_VERSION: "0.8.0"
+ PYTHON_VERSION: "3.10"
+ DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
+
+# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ # This job runs pytests with the default dependencies.
+ # It runs everytime we commit to a PR or push to main
+ fast-pytest-tests:
+ name: Fast Pytest Tests
+ runs-on: ubuntu-latest
+ env:
+ MUJOCO_GL: egl
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ persist-credentials: false
+ lfs: true
+
+ # TODO(Steven): Evaluate the need of these dependencies
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update && sudo apt-get install -y build-essential git \
+ curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
+ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev
+
+ - name: Setup uv and Python
+ uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ enable-cache: true
+ version: ${{ env.UV_VERSION }}
+ python-version: ${{ env.PYTHON_VERSION }}
+
+ - name: Install lerobot with test extras
+ run: uv sync --extra "test"
+
+ - name: Run pytest
+ run: uv run pytest tests -vv --maxfail=10
diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml
new file mode 100644
index 000000000..d16fe5e72
--- /dev/null
+++ b/.github/workflows/full_tests.yml
@@ -0,0 +1,210 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles full testing.
+name: Full Tests
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ pull_request_review:
+ types: [submitted]
+ push:
+ branches:
+ - main
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/workflows/**"
+ - "pyproject.toml"
+ - "Makefile"
+
+permissions:
+ contents: read
+
+# Sets up the environment variables
+env:
+ UV_VERSION: "0.8.0"
+ PYTHON_VERSION: "3.10"
+ DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
+
+# Ensures that only the latest action is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+
+ # This job runs the E2E tests + pytest with all extras
+ # It runs everytime a PR is approved or a push to main
+ full-tests:
+ name: Full Tests
+ runs-on: ubuntu-latest
+ if: |
+ (github.event_name == 'pull_request_review' && github.event.review.state == 'approved') ||
+ github.event_name == 'push' ||
+ github.event_name == 'workflow_dispatch'
+ env:
+ MUJOCO_GL: egl
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update && sudo apt-get install -y build-essential \
+ git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
+ speech-dispatcher libgeos-dev portaudio19-dev
+
+ - name: Setup uv and Python
+ uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ enable-cache: true
+ version: ${{ env.UV_VERSION }}
+ python-version: ${{ env.PYTHON_VERSION }}
+
+ - name: Install lerobot with all extras
+ run: uv sync --all-extras
+
+ - name: Run pytest (all extras)
+ run: uv run pytest tests -vv --maxfail=10
+
+ - name: Run end-to-end tests
+ run: uv run make test-end-to-end
+
+ # This job builds a GPU enabled image for testing
+ # It runs everytime a PR is approved or a push to main
+ # TODO(Steven): For now we skip this job for community PRs
+ build-and-push-docker:
+ name: Build and Push Docker
+ runs-on:
+ group: aws-general-8-plus
+ if: |
+ (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
+ github.event_name == 'push' ||
+ github.event_name == 'workflow_dispatch'
+ outputs:
+ image_tag: ${{ steps.set_tag.outputs.image_tag }}
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_REF: ${{ github.ref }}
+ GITHUB_PR_NUMBER: ${{ github.event.pull_request.number }}
+ steps:
+ - name: Set Docker image tag
+ id: set_tag
+ run: |
+ if [[ "${GITHUB_EVENT_NAME}" == "push" ]]; then
+ TAG="${DOCKER_IMAGE_NAME}:latest"
+ elif [[ -n "${GITHUB_PR_NUMBER}" ]]; then
+ TAG="${DOCKER_IMAGE_NAME}:pr-${GITHUB_PR_NUMBER}"
+ else
+ TAG="${DOCKER_IMAGE_NAME}:pr-${GITHUB_REF##*/}"
+ fi
+ echo "image_tag=$TAG" >> $GITHUB_OUTPUT
+ - name: Install Git LFS
+ run: |
+ sudo apt-get update
+ sudo apt-get install git-lfs
+ git lfs install
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ cache-binary: false
+ - name: Login to Docker Hub
+ uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ context: .
+ file: ./docker/Dockerfile.internal
+ push: true
+ tags: ${{ steps.set_tag.outputs.image_tag }}
+
+ # This job runs pytest with all extras in a GPU enabled host
+ # It runs everytime a test image is created
+ gpu-tests:
+ name: GPU Tests
+ needs: [build-and-push-docker]
+ runs-on:
+ group: aws-g6-4xlarge-plus
+ env:
+ HF_HOME: /home/user_lerobot/.cache/huggingface
+ HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+ TORCH_HOME: /home/user_lerobot/.cache/torch
+ TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+ container:
+ image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ options: --gpus all --shm-size "16gb"
+ credentials:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: /lerobot
+ steps:
+ - name: Run pytest on GPU
+ run: pytest tests -vv --maxfail=10
+ - name: Run end-to-end tests
+ run: make test-end-to-end
+
+ # This job deletes the test image recently created
+ # It runs everytime after the gpu-tests have finished
+ delete-pr-image:
+ name: Delete PR Image
+ needs: [gpu-tests, build-and-push-docker]
+ if: always() && ((github.event.review.state == 'approved') || (github.event_name == 'workflow_dispatch')) && needs.build-and-push-docker.result == 'success'
+ runs-on: ubuntu-latest
+ steps:
+ - name: Get Docker Hub Token and Delete Image
+ # zizmor: ignore[template-injection]
+ run: |
+ IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
+ IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
+
+ echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
+
+ TOKEN=$(curl -s -H "Content-Type: application/json" \
+ -X POST \
+ -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
+ https://hub.docker.com/v2/users/login/ | jq -r .token)
+
+ if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
+ echo "::error::Failed to get Docker Hub token."
+ exit 1
+ fi
+
+ HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
+ -H "Authorization: JWT ${TOKEN}" \
+ -X DELETE \
+ https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
+
+ if [ "$HTTP_RESPONSE" -eq 204 ]; then
+ echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
+ else
+ echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
+ exit 1
+ fi
+
+# TODO(Steven): Check dockerimages pull in ubuntu
diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml
deleted file mode 100644
index 728016915..000000000
--- a/.github/workflows/nightly-tests.yml
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Inspired by
-# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
-name: Nightly
-
-on:
- workflow_dispatch:
- schedule:
- - cron: "0 2 * * *"
-
-permissions: {}
-
-# env:
- # SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
-jobs:
- run_all_tests_cpu:
- name: CPU
- strategy:
- fail-fast: false
- runs-on:
- group: aws-general-8-plus
- container:
- image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
- options: --shm-size "16gb"
- credentials:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_PASSWORD }}
- defaults:
- run:
- shell: bash
- working-directory: /lerobot
- steps:
- - name: Tests
- run: pytest -v --cov=./src/lerobot --disable-warnings tests
-
- - name: Tests end-to-end
- run: make test-end-to-end
-
-
- run_all_tests_single_gpu:
- name: GPU
- strategy:
- fail-fast: false
- runs-on:
- group: aws-g6-4xlarge-plus
- env:
- CUDA_VISIBLE_DEVICES: "0"
- TEST_TYPE: "single_gpu"
- container:
- image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
- options: --gpus all --shm-size "16gb"
- credentials:
- username: ${{ secrets.DOCKERHUB_USERNAME }}
- password: ${{ secrets.DOCKERHUB_PASSWORD }}
- defaults:
- run:
- shell: bash
- working-directory: /lerobot
- steps:
- - name: Nvidia-smi
- run: nvidia-smi
-
- - name: Test
- run: pytest -v --cov=./src/lerobot --cov-report=xml --disable-warnings tests
- # TODO(aliberts): Link with HF Codecov account
- # - name: Upload coverage reports to Codecov with GitHub Action
- # uses: codecov/codecov-action@v4
- # with:
- # files: ./coverage.xml
- # verbose: true
- - name: Tests end-to-end
- env:
- DEVICE: cuda
- run: make test-end-to-end
-
- # - name: Generate Report
- # if: always()
- # run: |
- # pip install slack_sdk tabulate
- # python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
new file mode 100644
index 000000000..03f26a792
--- /dev/null
+++ b/.github/workflows/nightly.yml
@@ -0,0 +1,160 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles nightly testing & docker images publishing.
+name: Nightly
+permissions:
+ contents: read
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ # Runs at 02:00
+ schedule:
+ - cron: "0 2 * * *"
+
+# Sets up the environment variables
+env:
+ UV_VERSION: "0.8.0"
+ PYTHON_VERSION: "3.10"
+ DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
+ DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
+
+# Ensures that only the latest commit is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ # This job builds a CPU image for testing & distribution
+ build-docker-cpu-nightly:
+ name: Build CPU Docker for Nightly
+ runs-on:
+ group: aws-general-8-plus
+ outputs:
+ image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
+ steps:
+ - name: Install Git LFS
+ run: |
+ sudo apt-get update
+ sudo apt-get install git-lfs
+ git lfs install
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ cache-binary: false
+ - name: Login to Docker Hub
+ uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ - name: Build and push Docker image CPU
+ uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ context: .
+ file: ./docker/Dockerfile.user
+ push: true
+ tags: ${{ env.DOCKER_IMAGE_NAME_CPU }}
+
+ # This job builds a GPU image for testing & distribution
+ build-docker-gpu-nightly:
+ name: Build GPU Docker for Nightly
+ runs-on:
+ group: aws-general-8-plus
+ outputs:
+ image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
+ steps:
+ - name: Install Git LFS
+ run: |
+ sudo apt-get update
+ sudo apt-get install git-lfs
+ git lfs install
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ cache-binary: false
+ - name: Login to Docker Hub
+ uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ - name: Build and push Docker image GPU
+ uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ context: .
+ file: ./docker/Dockerfile.internal
+ push: true
+ tags: ${{ env.DOCKER_IMAGE_NAME_GPU }}
+
+ # This job runs the E2E tests + pytest with all extras in the CPU image
+ nightly-cpu-tests:
+ name: Nightly CPU Tests
+ needs: [build-docker-cpu-nightly]
+ runs-on:
+ group: aws-g6-4xlarge-plus
+ env:
+ HF_HOME: /home/user_lerobot/.cache/huggingface
+ HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+ TORCH_HOME: /home/user_lerobot/.cache/torch
+ TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+ container:
+ image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ credentials:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: /lerobot
+ steps:
+ - name: Run pytest on CPU
+ run: pytest tests -vv --maxfail=10
+ - name: Run end-to-end tests
+ run: make test-end-to-end
+
+ # This job runs the E2E tests + pytest with all extras in the GPU image
+ nightly-gpu-tests:
+ name: Nightly GPU Tests
+ needs: [build-docker-gpu-nightly]
+ runs-on:
+ group: aws-g6-4xlarge-plus
+ env:
+ HF_HOME: /home/user_lerobot/.cache/huggingface
+ HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+ TORCH_HOME: /home/user_lerobot/.cache/torch
+ TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+ container:
+ image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ options: --gpus all --shm-size "16gb"
+ credentials:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: /lerobot
+ steps:
+ - name: Run pytest on GPU
+ run: pytest tests -vv --maxfail=10
+ - name: Run end-to-end tests
+ run: make test-end-to-end
diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index 1c048c4fe..e9f73ed23 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,61 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# This workflow handles linting, formatting, and static analysis checks for the codebase.
name: Quality
+permissions:
+ contents: read
on:
+ # Allows running this workflow manually from the Actions tab
workflow_dispatch:
- workflow_call:
- pull_request:
+
+ # Triggers the workflow on push events to main
push:
branches:
- main
-permissions: {}
+ # Triggers the workflow on pull request events targeting main
+ pull_request:
+ branches:
+ - main
-env:
- PYTHON_VERSION: "3.10"
+# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
jobs:
- style:
- name: Style
+ # This job runs pre-commit hooks to check code style and formatting.
+ pre-commit-checks:
+ name: Run Pre-commit Hooks (Lint, Format & Static Analysis)
runs-on: ubuntu-latest
steps:
- - name: Checkout Repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Checkout code
+ uses: actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
- uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
+ uses: actions/setup-python@v5
with:
- python-version: ${{ env.PYTHON_VERSION }}
+ python-version: '3.10'
- - name: Get Ruff Version from pre-commit-config.yaml
- id: get-ruff-version
- run: |
- RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
- echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT
-
- - name: Install Ruff
- env:
- RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }}
- run: python -m pip install "ruff==${RUFF_VERSION}"
-
- - name: Ruff check
- run: ruff check --output-format=github
-
- - name: Ruff format
- run: ruff format --diff
-
- typos:
- name: Typos
- runs-on: ubuntu-latest
- steps:
- - name: Checkout Repository
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Run pre-commit hooks
+ uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
with:
- persist-credentials: false
-
- - name: typos-action
- uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
+ extra_args: --all-files --show-diff-on-failure --color=always
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 000000000..67aa5186b
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,171 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: Create Release and Publish to PyPI
+
+on:
+ push:
+ tags:
+ - 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0
+
+# Sets up the environment variables
+env:
+ UV_VERSION: "0.8.0"
+ PYTHON_VERSION: "3.10"
+
+jobs:
+ # This job builds the Python package and publishes it to PyPI
+ build-and-publish:
+ name: Build and publish Python distributions
+ runs-on: ubuntu-latest
+ outputs:
+ version: ${{ steps.extract_info.outputs.tag_version }}
+ permissions:
+ contents: write
+ id-token: write
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ persist-credentials: false
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Extract Version
+ id: extract_info
+ # Extract version from tag (e.g., v0.1.0 -> 0.1.0)
+ # zizmor: ignore[template-injection]
+ run: |
+ VERSION=${{ github.ref_name }}
+ VERSION_NUMBER=${VERSION#v}
+ echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT
+ - name: Check if version matches pyproject.toml
+ if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
+ # zizmor: ignore[template-injection]
+ run: |
+ TAG_VERSION=${{ steps.extract_info.outputs.tag_version }}
+
+ PYPROJECT_VERSION=$(grep '^version = ' pyproject.toml | awk -F' = ' '{print $2}' | tr -d '"')
+
+ if [[ "$TAG_VERSION" != "$PYPROJECT_VERSION" ]]; then
+ echo "Error: Tag version ($TAG_VERSION) does not match pyproject.toml version ($PYPROJECT_VERSION)." >&2
+ exit 1
+ else
+ echo "Tag version matches pyproject.toml version: $TAG_VERSION. Proceeding with release."
+ fi
+
+ - name: Check if version exists on PyPI
+ # zizmor: ignore[template-injection]
+ run: |
+ NEW_VERSION=${{ steps.extract_info.outputs.tag_version }}
+
+ response=$(curl -s "https://pypi.org/pypi/lerobot/$NEW_VERSION/json")
+ if echo "$response" | grep -q "message"; then
+ echo "Version $NEW_VERSION is available on PyPI. Proceeding with release."
+ else
+ echo "Error: Version $NEW_VERSION already exists on PyPI. Aborting."
+ exit 1
+ fi
+
+ - name: Install build dependencies
+ run: python -m pip install build
+
+ - name: Build package
+ run: python -m build
+
+ - name: Create GitHub Release
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ # zizmor: ignore[template-injection]
+ run: |
+ gh release create ${{ github.ref_name }} \
+ --title "Release ${{ github.ref_name }}" \
+ --generate-notes \
+ --draft=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \
+ --prerelease=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \
+ ./dist/*
+
+ - name: Publish to TestPyPI for pre-releases
+ # True for tags like 'v0.2.0-rc1'
+ if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
+ uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
+ with:
+ repository-url: https://test.pypi.org/legacy/
+ verbose: true
+ print-hash: true
+
+ - name: Publish to PyPI
+ if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
+ uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
+ with:
+ verbose: true
+ print-hash: true
+
+ # This job runs end-to-end tests on the release
+ test-release:
+ name: Test Release
+ needs: [build-and-publish]
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ env:
+ MUJOCO_GL: egl
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update && sudo apt-get install -y build-essential \
+ git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
+ speech-dispatcher libgeos-dev portaudio19-dev
+ - name: Setup uv and Python
+ uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ enable-cache: true
+ version: ${{ env.UV_VERSION }}
+ python-version: ${{ env.PYTHON_VERSION }}
+ - name: Create uv virtual environment
+ run: uv venv
+ - name: Install lerobot release
+ # zizmor: ignore[template-injection]
+ run: |
+ VERSION="${{ needs.build-and-publish.outputs.version }}"
+ if [[ "$VERSION" == *-* ]]; then
+ BASE_VERSION="${VERSION%%-*}"
+ echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
+ uv pip install \
+ --index-url https://test.pypi.org/simple/ \
+ --extra-index-url https://pypi.org/simple \
+ --index-strategy unsafe-best-match \
+ "lerobot[all]==$BASE_VERSION"
+ else
+ echo "Installing release version $VERSION from PyPI..."
+ uv pip install "lerobot[all]==$VERSION"
+ fi
+ - name: Check lerobot version
+ run: uv run python -c "import lerobot; print(lerobot.__version__)"
+
+ - name: Run end-to-end tests
+ run: uv run make test-end-to-end
+
+
+# TODO(Steven): Publish draft/pre-release and to test pypi weekly
+# TODO(Steven): Separate build and publish job
+# TODO(Steven): Tag documentation with the same version as the package
diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml
new file mode 100644
index 000000000..04497307b
--- /dev/null
+++ b/.github/workflows/security.yml
@@ -0,0 +1,54 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles secret scanning using TruffleHog to detect sensitive information in the codebase.
+name: Security
+permissions:
+ contents: read
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ # Triggers the workflow on push events to main
+ push:
+ branches:
+ - main
+
+ # Triggers the workflow on pull request events targeting main
+ pull_request:
+ branches:
+ - main
+
+# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ # This job runs TruffleHog to scan the full history of the repository for secrets.
+ trufflehog:
+ name: Secret Leaks Scan
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses]
+ with:
+ fetch-depth: 0
+ persist-credentials: false
+
+ - name: Secret Scanning
+ uses: trufflesecurity/trufflehog@v3.90.0 # zizmor: ignore[unpinned-uses]
+ with:
+ extra_args: --only-verified
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
new file mode 100644
index 000000000..af91c9f58
--- /dev/null
+++ b/.github/workflows/stale.yml
@@ -0,0 +1,68 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles closing stale issues and PRs.
+name: Stale
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ # Runs at 02:00
+ schedule:
+ - cron: "0 2 * * *"
+
+env:
+ CLOSE_ISSUE_MESSAGE: >
+ This issue was closed because it has been stalled for 14 days with no activity.
+ Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
+ CLOSE_PR_MESSAGE: >
+ This PR was closed because it has been stalled for 14 days with no activity.
+ Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
+ WARN_ISSUE_MESSAGE: >
+ This issue has been automatically marked as stale because it has not had
+ recent activity (6 months). It will be closed if no further activity occurs.
+ Thank you for your contributions.
+ WARN_PR_MESSAGE: >
+ This PR has been automatically marked as stale because it has not had
+ recent activity (6 months). It will be closed if no further activity occurs.
+ Thank you for your contributions.
+
+jobs:
+ # This job runs the actions/stale action to close stale issues and PRs.
+ stale:
+ name: Close Stale Issues and PRs
+ runs-on: ubuntu-latest
+ permissions:
+ actions: write
+ contents: write # only for delete-branch option
+ issues: write
+ pull-requests: write
+ steps:
+ - uses: actions/stale@v10
+ with:
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
+ stale-issue-label: stale
+ stale-pr-label: stale
+ exempt-issue-labels: never-stale
+ exempt-pr-labels: never-stale
+ days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
+ days-before-issue-close: 14
+ days-before-pr-stale: 180
+ days-before-pr-close: 14
+ delete-branch: true
+ close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
+ close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
+ stale-issue-message: ${{ env.WARN_ISSUE_MESSAGE }}
+ stale-pr-message: ${{ env.WARN_PR_MESSAGE }}
+ operations-per-run: 500
diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml
deleted file mode 100644
index 7a1e93274..000000000
--- a/.github/workflows/test-docker-build.yml
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Inspired by
-# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
-name: Test Dockerfiles
-
-on:
- pull_request:
- paths:
- # Run only when DockerFile files are modified
- - "docker/**"
-
-permissions: {}
-
-env:
- PYTHON_VERSION: "3.10"
-
-jobs:
- get_changed_files:
- name: Detect modified Dockerfiles
- runs-on: ubuntu-latest
- outputs:
- matrix: ${{ steps.set-matrix.outputs.matrix }}
- steps:
- - name: Check out code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- persist-credentials: false
-
- - name: Get changed files
- id: changed-files
- uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
- with:
- files: docker/**
- json: "true"
-
- - name: Run step if only the files listed above change # zizmor: ignore[template-injection]
- if: steps.changed-files.outputs.any_changed == 'true'
- id: set-matrix
- run: |
- echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
-
- build_modified_dockerfiles:
- name: Build modified Docker images
- needs: get_changed_files
- runs-on:
- group: aws-general-8-plus
- if: needs.get_changed_files.outputs.matrix != ''
- strategy:
- fail-fast: false
- matrix:
- docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
- steps:
- - name: Set up Docker Buildx
- uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
- with:
- cache-binary: false
-
- - name: Check out code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- persist-credentials: false
-
- - name: Build Docker image
- uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
- with:
- file: ${{ matrix.docker-file }}
- context: .
- push: False
- build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
deleted file mode 100644
index d6ea1d404..000000000
--- a/.github/workflows/test.yml
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: Tests
-
-on:
- pull_request:
- paths:
- - "src/**"
- - "tests/**"
- - "examples/**"
- - ".github/**"
- - "pyproject.toml"
- - ".pre-commit-config.yaml"
- - "Makefile"
- - ".cache/**"
- push:
- branches:
- - main
- paths:
- - "src/**"
- - "tests/**"
- - "examples/**"
- - ".github/**"
- - "pyproject.toml"
- - ".pre-commit-config.yaml"
- - "Makefile"
- - ".cache/**"
-
-permissions: {}
-
-env:
- UV_VERSION: "0.6.0"
-
-jobs:
- pytest:
- name: Pytest
- runs-on: ubuntu-latest
- env:
- MUJOCO_GL: egl
- steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- lfs: true # Ensure LFS files are pulled
- persist-credentials: false
-
- - name: Install apt dependencies
- # portaudio19-dev is needed to install pyaudio
- run: |
- sudo apt-get update && \
- sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
-
- - name: Install uv and python
- uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
- with:
- enable-cache: true
- version: ${{ env.UV_VERSION }}
- python-version: "3.10"
-
- - name: Install lerobot (all extras)
- run: uv sync --all-extras
-
- - name: Test with pytest
- run: |
- uv run pytest tests -v --cov=./src/lerobot --durations=0 \
- -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
- -W ignore::UserWarning:torch.utils.data.dataloader:558 \
- -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
- && rm -rf tests/outputs outputs
-
- pytest-minimal:
- name: Pytest (minimal install)
- runs-on: ubuntu-latest
- env:
- MUJOCO_GL: egl
- steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- lfs: true # Ensure LFS files are pulled
- persist-credentials: false
-
- - name: Install apt dependencies
- run: sudo apt-get update && sudo apt-get install -y ffmpeg
-
- - name: Install uv and python
- uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
- with:
- enable-cache: true
- version: ${{ env.UV_VERSION }}
- python-version: "3.10"
-
- - name: Install lerobot
- run: uv sync --extra "test"
-
- - name: Test with pytest
- run: |
- uv run pytest tests -v --cov=./src/lerobot --durations=0 \
- -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
- -W ignore::UserWarning:torch.utils.data.dataloader:558 \
- -W ignore::UserWarning:gymnasium.utils.env_checker:247 \
- && rm -rf tests/outputs outputs
-
- end-to-end:
- name: End-to-end
- runs-on: ubuntu-latest
- env:
- MUJOCO_GL: egl
- steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- lfs: true # Ensure LFS files are pulled
- persist-credentials: false
-
- - name: Install apt dependencies
- # portaudio19-dev is needed to install pyaudio
- run: |
- sudo apt-get update && \
- sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
-
- - name: Install uv and python
- uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
- with:
- enable-cache: true
- version: ${{ env.UV_VERSION }}
- python-version: "3.10"
-
- - name: Install lerobot (all extras)
- run: |
- uv venv
- uv sync --all-extras
-
- - name: venv
- run: |
- echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV
-
- - name: Test end-to-end
- run: |
- make test-end-to-end \
- && rm -rf outputs
diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml
new file mode 100644
index 000000000..902074a83
--- /dev/null
+++ b/.github/workflows/unbound_deps_tests.yml
@@ -0,0 +1,183 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow handles full testing with unboud dependencies versions.
+name: Unbound Dependency Tests
+
+on:
+ # Allows running this workflow manually from the Actions tab
+ workflow_dispatch:
+
+ # Run on the 1st and 15th of every month at 09:00 UTC
+ schedule:
+ - cron: '0 2 1,15 * *'
+
+permissions:
+ contents: read
+
+# Sets up the environment variables
+env:
+ UV_VERSION: "0.8.0"
+ PYTHON_VERSION: "3.10"
+ DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
+
+# Ensures that only the latest action is built, canceling older runs.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+
+ # This job runs the E2E tests + pytest with all unbound extras
+ full-tests:
+ name: Full Unbound Tests
+ runs-on: ubuntu-latest
+ env:
+ MUJOCO_GL: egl
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update && sudo apt-get install -y build-essential \
+ git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \
+ speech-dispatcher libgeos-dev portaudio19-dev
+
+ - name: Setup uv and Python
+ uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ enable-cache: true
+ version: ${{ env.UV_VERSION }}
+ python-version: ${{ env.PYTHON_VERSION }}
+
+ - name: Unbound dependencies
+ run: |
+ sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml
+ echo "Dependencies unbound:" && cat pyproject.toml
+
+ - name: Install lerobot with all extras
+ run: uv sync --all-extras
+
+ - name: Run pytest (all extras)
+ run: uv run pytest tests -vv
+
+ - name: Run end-to-end tests
+ run: uv run make test-end-to-end
+
+ # This job builds a GPU enabled image for testing
+ build-and-push-docker:
+ name: Build and Push Docker
+ runs-on:
+ group: aws-general-8-plus
+ outputs:
+ image_tag: ${{ env.DOCKER_IMAGE_NAME }}
+ env:
+ GITHUB_REF: ${{ github.ref }}
+ steps:
+ - name: Install Git LFS
+ run: |
+ sudo apt-get update
+ sudo apt-get install git-lfs
+ git lfs install
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ persist-credentials: false
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ cache-binary: false
+ - name: Login to Docker Hub
+ uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
+ with:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
+ with:
+ context: .
+ file: ./docker/Dockerfile.internal
+ push: true
+ tags: ${{ env.DOCKER_IMAGE_NAME }}
+ build-args: |
+ UNBOUND_DEPS=true
+
+ # This job runs pytest with all unbound extras in a GPU enabled host
+ # It runs everytime a test image is created
+ gpu-tests:
+ name: GPU Unbound Tests
+ needs: [build-and-push-docker]
+ runs-on:
+ group: aws-g6-4xlarge-plus
+ env:
+ HF_HOME: /home/user_lerobot/.cache/huggingface
+ HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+ TORCH_HOME: /home/user_lerobot/.cache/torch
+ TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+ container:
+ image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+ options: --gpus all --shm-size "16gb"
+ credentials:
+ username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: /lerobot
+ steps:
+ - name: Run pytest on GPU
+ run: pytest tests -vv
+ - name: Run end-to-end tests
+ run: make test-end-to-end
+
+ # This job deletes the test image recently created
+ # It runs everytime after the gpu-tests have finished
+ delete-unbound-image:
+ name: Delete Unbound Image
+ needs: [gpu-tests, build-and-push-docker]
+ if: always() && needs.build-and-push-docker.result == 'success'
+ runs-on: ubuntu-latest
+ steps:
+ - name: Get Docker Hub Token and Delete Image
+ # zizmor: ignore[template-injection]
+ run: |
+ IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1)
+ IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2)
+
+ echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG"
+
+ TOKEN=$(curl -s -H "Content-Type: application/json" \
+ -X POST \
+ -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \
+ https://hub.docker.com/v2/users/login/ | jq -r .token)
+
+ if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then
+ echo "::error::Failed to get Docker Hub token."
+ exit 1
+ fi
+
+ HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \
+ -H "Authorization: JWT ${TOKEN}" \
+ -X DELETE \
+ https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/)
+
+ if [ "$HTTP_RESPONSE" -eq 204 ]; then
+ echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG"
+ else
+ echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE"
+ exit 1
+ fi
diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml
deleted file mode 100644
index 32665930b..000000000
--- a/.github/workflows/upload_pr_documentation.yml
+++ /dev/null
@@ -1,16 +0,0 @@
-name: Upload PR Documentation
-
-on: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers
- workflow_run:
- workflows: [ "Build PR Documentation" ]
- types:
- - completed
-
-jobs:
- build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
- uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
- with:
- package_name: lerobot
- secrets:
- hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
- comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
diff --git a/.gitignore b/.gitignore
index 4ab886933..b47e22cbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,164 +12,168 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Dev scripts
-.dev
-
-# Logging
-logs
-tmp
-wandb
-
-# Data
-data
-outputs
-
-# Apple
-.DS_Store
-
-# VS Code
-.vscode
-.devcontainer
-
-# HPC
-nautilus/*.yaml
-*.key
-
-# Slurm
-sbatch*.sh
-
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# uv/poetry lock files
-poetry.lock
-uv.lock
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-!tests/artifacts
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-
-# Ignore .cache
-.cache/*
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
+### Environments & Dependencies ###
.env
.venv
env/
venv/
env.bak/
venv.bak/
+.python-version
+__pypackages__/
+node_modules/
-# Spyder project settings
+# Lock files
+poetry.lock
+uv.lock
+Pipfile.lock
+
+### Build & Distribution ###
+build/
+dist/
+sdist/
+wheels/
+downloads/
+eggs/
+.eggs/
+parts/
+var/
+pip-wheel-metadata/
+share/python-wheels/
+develop-eggs/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+lib/
+lib64/
+
+# PyInstaller
+*.manifest
+*.spec
+
+### Compiled & Cached Files ###
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+*.sage.py
+.cache/
+.ruff_cache/
+.mypy_cache/
+.pyre/
+.pytype/
+cython_debug/
+
+### Testing & Coverage ###
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.pytest_cache/
+.hypothesis/
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+!tests/artifacts
+
+### Logs & Temporary Files ###
+logs/
+tmp/
+*.log
+pip-log.txt
+pip-delete-this-directory.txt
+celerybeat-schedule
+celerybeat.pid
+
+### IDE & Editor Config ###
+# VS Code
+.vscode/
+.devcontainer/
+
+# JetBrains / PyCharm
+.idea/
+
+# Spyder
.spyderproject
.spyproject
-# Rope project settings
+# Rope
.ropeproject
-# mkdocs documentation
+# Vim
+*.swp
+
+# Other
+*~
+
+### OS Specific ###
+# macOS
+.DS_Store
+
+# Windows
+Thumbs.db
+
+### Framework & Tool Specific ###
+
+.Python
+
+# Django
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask
+instance/
+.webassets-cache
+
+# Scrapy
+.scrapy
+
+# Jupyter
+.ipynb_checkpoints/
+profile_default/
+ipython_config.py
+
+# Sphinx
+docs/_build/
+
+# MkDocs
/site
+# PyBuilder
+.pybuilder/
+target/
+
# mypy
-.mypy_cache/
.dmypy.json
dmypy.json
-# Pyre type checker
-.pyre/
+### HPC & Slurm ###
+nautilus/*.yaml
+*.key
+sbatch*.sh
-# pytype static type analyzer
-.pytype/
+### Miscellaneous ###
+# W&B
+wandb/
-# Cython debug symbols
-cython_debug/
+# Dev scripts
+.dev/
+
+# Data folders
+data/
+outputs/
+
+# Translations
+*.mo
+*.pot
+
+# Dev folders
+.cache/*
+*.stl
+*.urdf
+*.xml
+*.part
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e1f971d39..7f5beff80 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-exclude: "tests/artifacts/.*\\.safetensors$"
default_language_version:
python: python3.10
+
+exclude: "tests/artifacts/.*\\.safetensors$"
+
repos:
##### Meta #####
- repo: meta
@@ -22,12 +24,12 @@ repos:
- id: check-useless-excludes
- id: check-hooks-apply
-
- ##### Style / Misc. #####
+ ##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
+ args: ['--maxkb=1024']
- id: debug-statements
- id: check-merge-conflict
- id: check-case-conflict
@@ -36,8 +38,15 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.12.4
+ hooks:
+ - id: ruff-format
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+
- repo: https://github.com/adhtruong/mirrors-typos
- rev: v1.33.1
+ rev: v1.34.0
hooks:
- id: typos
args: [--force-exclude]
@@ -46,14 +55,16 @@ repos:
rev: v3.20.0
hooks:
- id: pyupgrade
+ args: [--py310-plus]
- - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.11.13
+ ##### Markdown Quality #####
+ - repo: https://github.com/rbubley/mirrors-prettier
+ rev: v3.6.2
hooks:
- - id: ruff
- args: [--fix]
- - id: ruff-format
-
+ - id: prettier
+ name: Format Markdown with Prettier
+ types_or: [markdown, mdx]
+ args: [--prose-wrap=preserve]
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
@@ -62,13 +73,36 @@ repos:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
- rev: v1.9.0
+ rev: v1.11.0
hooks:
- id: zizmor
- repo: https://github.com/PyCQA/bandit
- rev: 1.8.3
+ rev: 1.8.6
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]
+
+ # TODO(Steven): Uncomment when ready to use
+ ##### Static Analysis & Typing #####
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.16.0
+ hooks:
+ - id: mypy
+ args: [--config-file=pyproject.toml]
+ exclude: ^(examples|benchmarks|tests)/
+
+ ##### Docstring Checks #####
+ # - repo: https://github.com/akaihola/darglint2
+ # rev: v1.8.2
+ # hooks:
+ # - id: darglint2
+ # args: ["--docstring-style", "google", "-v", "2"]
+ # exclude: ^tests/.*$
+
+ # - repo: https://github.com/econchick/interrogate
+ # rev: 1.7.0
+ # hooks:
+ # - id: interrogate
+ # args: ["-vv", "--config=pyproject.toml"]
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 04a052753..c0fdac843 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -1,4 +1,3 @@
-
# Contributor Covenant Code of Conduct
## Our Pledge
@@ -18,23 +17,23 @@ diverse, inclusive, and healthy community.
Examples of behavior that contributes to a positive environment for our
community include:
-* Demonstrating empathy and kindness toward other people
-* Being respectful of differing opinions, viewpoints, and experiences
-* Giving and gracefully accepting constructive feedback
-* Accepting responsibility and apologizing to those affected by our mistakes,
+- Demonstrating empathy and kindness toward other people
+- Being respectful of differing opinions, viewpoints, and experiences
+- Giving and gracefully accepting constructive feedback
+- Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
-* Focusing on what is best not just for us as individuals, but for the overall
+- Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
-* The use of sexualized language or imagery, and sexual attention or advances of
+- The use of sexualized language or imagery, and sexual attention or advances of
any kind
-* Trolling, insulting or derogatory comments, and personal or political attacks
-* Public or private harassment
-* Publishing others' private information, such as a physical or email address,
+- Trolling, insulting or derogatory comments, and personal or political attacks
+- Public or private harassment
+- Publishing others' private information, such as a physical or email address,
without their explicit permission
-* Other conduct which could reasonably be considered inappropriate in a
+- Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a354e1346..369af602b 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -15,10 +15,11 @@ Whichever way you choose to contribute, please be mindful to respect our
## You can contribute in so many ways!
Some of the ways you can contribute to 🤗 LeRobot:
-* Fixing outstanding issues with the existing code.
-* Implementing new models, datasets or simulation environments.
-* Contributing to the examples or to the documentation.
-* Submitting issues related to bugs or desired new features.
+
+- Fixing outstanding issues with the existing code.
+- Implementing new models, datasets or simulation environments.
+- Contributing to the examples or to the documentation.
+- Submitting issues related to bugs or desired new features.
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co).
@@ -40,24 +41,26 @@ already reported** (use the search bar on Github under Issues).
Did not find it? :( So we can act quickly on it, please follow these steps:
-* Include your **OS type and version**, the versions of **Python** and **PyTorch**.
-* A short, self-contained, code snippet that allows us to reproduce the bug in
+- Include your **OS type and version**, the versions of **Python** and **PyTorch**.
+- A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
-* The full traceback if an exception is raised.
-* Attach any other additional information, like screenshots, you think may help.
+- The full traceback if an exception is raised.
+- Attach any other additional information, like screenshots, you think may help.
### Do you want a new feature?
A good feature request addresses the following points:
1. Motivation first:
-* Is it related to a problem/frustration with the library? If so, please explain
+
+- Is it related to a problem/frustration with the library? If so, please explain
why. Providing a code snippet that demonstrates the problem is best.
-* Is it related to something you would need for a project? We'd love to hear
+- Is it related to something you would need for a project? We'd love to hear
about it!
-* Is it something you worked on and think could benefit the community?
+- Is it something you worked on and think could benefit the community?
Awesome! Tell us what problem it solved for you.
-2. Write a *paragraph* describing the feature.
+
+2. Write a _paragraph_ describing the feature.
3. Provide a **code snippet** that demonstrates its future use.
4. In case this is related to a paper, please attach a link.
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
@@ -74,12 +77,15 @@ environments ([aloha](https://github.com/huggingface/gym-aloha),
and follow the same api design.
When implementing a new dataset loadable with LeRobotDataset follow these steps:
+
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
+
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
+
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
@@ -133,11 +139,13 @@ Follow these steps to start contributing:
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda or miniconda:
+
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
```
If you're using `uv`, it can manage python versions so you can instead do:
+
```bash
uv venv --python 3.10 && source .venv/bin/activate
```
@@ -145,11 +153,13 @@ Follow these steps to start contributing:
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
using `poetry`
+
```bash
poetry sync --extras "dev test"
```
using `uv`
+
```bash
uv sync --extra dev --extra test
```
@@ -157,43 +167,48 @@ Follow these steps to start contributing:
You can also install the project with all its dependencies (including environments):
using `poetry`
+
```bash
poetry sync --all-extras
```
using `uv`
+
```bash
uv sync --all-extras
```
- > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
+ > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing.
Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
The equivalent of `pip install some-package`, would just be:
using `poetry`
+
```bash
poetry add some-package
```
using `uv`
+
```bash
uv add some-package
```
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
using `poetry`
+
```bash
poetry lock
```
using `uv`
+
```bash
uv lock
```
-
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
@@ -211,11 +226,13 @@ Follow these steps to start contributing:
automatically as Git commit hooks.
Install `pre-commit` hooks:
+
```bash
pre-commit install
```
You can run these hooks whenever you need on staged files with:
+
```bash
pre-commit
```
@@ -229,6 +246,7 @@ Follow these steps to start contributing:
```
Note, if you already committed some changes that have a wrong formatting, you can use:
+
```bash
pre-commit run --all-files
```
@@ -249,16 +267,15 @@ Follow these steps to start contributing:
git push -u origin a-descriptive-name-for-my-changes
```
-6. Once you are satisfied (**and the checklist below is happy too**), go to the
+7. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
-7. It's ok if maintainers ask you for changes. It happens to core contributors
+8. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
-
### Checklist
1. The title of your pull request should be a summary of its contribution;
@@ -277,18 +294,21 @@ An extensive test suite is included to test the library behavior and several exa
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
+
```bash
brew install git-lfs
git lfs install
```
On Ubuntu:
+
```bash
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
+
```bash
git lfs pull
```
@@ -300,6 +320,5 @@ repository, here's how to run tests with `pytest` for the library:
python -m pytest -sv ./tests
```
-
You can specify a smaller set of tests in order to test only the feature
you're working on.
diff --git a/Makefile b/Makefile
index ca1495fac..fbe8a5bae 100644
--- a/Makefile
+++ b/Makefile
@@ -26,11 +26,11 @@ export PATH := $(dir $(PYTHON_PATH)):$(PATH)
DEVICE ?= cpu
-build-cpu:
- docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
+build-user:
+ docker build -f docker/Dockerfile.user -t lerobot-user .
-build-gpu:
- docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
+build-internal:
+ docker build -f docker/Dockerfile.internal -t lerobot-internal .
test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
@@ -44,7 +44,7 @@ test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval
test-act-ete-train:
- python -m lerobot.scripts.train \
+ lerobot-train \
--policy.type=act \
--policy.dim_model=64 \
--policy.n_action_steps=20 \
@@ -68,12 +68,12 @@ test-act-ete-train:
--output_dir=tests/outputs/act/
test-act-ete-train-resume:
- python -m lerobot.scripts.train \
+ lerobot-train \
--config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \
--resume=true
test-act-ete-eval:
- python -m lerobot.scripts.eval \
+ lerobot-eval \
--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=aloha \
@@ -82,7 +82,7 @@ test-act-ete-eval:
--eval.batch_size=1
test-diffusion-ete-train:
- python -m lerobot.scripts.train \
+ lerobot-train \
--policy.type=diffusion \
--policy.down_dims='[64,128,256]' \
--policy.diffusion_step_embed_dim=32 \
@@ -106,7 +106,7 @@ test-diffusion-ete-train:
--output_dir=tests/outputs/diffusion/
test-diffusion-ete-eval:
- python -m lerobot.scripts.eval \
+ lerobot-eval \
--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=pusht \
@@ -115,7 +115,7 @@ test-diffusion-ete-eval:
--eval.batch_size=1
test-tdmpc-ete-train:
- python -m lerobot.scripts.train \
+ lerobot-train \
--policy.type=tdmpc \
--policy.device=$(DEVICE) \
--policy.push_to_hub=false \
@@ -137,7 +137,7 @@ test-tdmpc-ete-train:
--output_dir=tests/outputs/tdmpc/
test-tdmpc-ete-eval:
- python -m lerobot.scripts.eval \
+ lerobot-eval \
--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=xarm \
@@ -148,7 +148,7 @@ test-tdmpc-ete-eval:
test-smolvla-ete-train:
- python -m lerobot.scripts.train \
+ lerobot-train \
--policy.type=smolvla \
--policy.n_action_steps=20 \
--policy.chunk_size=20 \
@@ -171,7 +171,7 @@ test-smolvla-ete-train:
--output_dir=tests/outputs/smolvla/
test-smolvla-ete-eval:
- python -m lerobot.scripts.eval \
+ lerobot-eval \
--policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \
--policy.device=$(DEVICE) \
--env.type=aloha \
diff --git a/README.md b/README.md
index 153a3a215..357e62cc1 100644
--- a/README.md
+++ b/README.md
@@ -1,48 +1,58 @@
Meet the updated SO100, the SO-101 – Just €114 per arm!
Train it in minutes with a few simple moves on your laptop.
@@ -54,7 +64,7 @@
Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!
Check out the LeKiwi tutorial and bring your robot to life on wheels.
-
+
@@ -77,9 +87,9 @@
-
-
-
+
+
+
ACT policy on ALOHA env
@@ -88,61 +98,97 @@
-### Acknowledgment
-
-- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
-- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
-- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
-- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
-- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
-- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
-
-
## Installation
-Download our source code:
-```bash
-git clone https://github.com/huggingface/lerobot.git
-cd lerobot
-```
+LeRobot works with Python 3.10+ and PyTorch 2.2+.
+
+### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
+
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `miniconda`, install `ffmpeg` in your environment:
+
```bash
conda install ffmpeg -c conda-forge
```
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
-> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
-> ```bash
-> conda install ffmpeg=7.1.1 -c conda-forge
-> ```
-> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
+>
+> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
+>
+> ```bash
+> conda install ffmpeg=7.1.1 -c conda-forge
+> ```
+>
+> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
+
+### Install LeRobot 🤗
+
+#### From Source
+
+First, clone the repository and navigate into the directory:
+
+```bash
+git clone https://github.com/huggingface/lerobot.git
+cd lerobot
+```
+
+Then, install the library in editable mode. This is useful if you plan to contribute to the code.
-Install 🤗 LeRobot:
```bash
pip install -e .
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
-`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
+> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
+
- [aloha](https://github.com/huggingface/gym-aloha)
- [xarm](https://github.com/huggingface/gym-xarm)
- [pusht](https://github.com/huggingface/gym-pusht)
For instance, to install 🤗 LeRobot with aloha and pusht, use:
+
```bash
pip install -e ".[aloha, pusht]"
```
+### Installation from PyPI
+
+**Core Library:**
+Install the base package with:
+
+```bash
+pip install lerobot
+```
+
+_This installs only the default dependencies._
+
+**Extra Features:**
+To install additional functionality, use one of the following:
+
+```bash
+pip install 'lerobot[all]' # All available features
+pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
+pip install 'lerobot[feetech]' # Feetech motor support
+```
+
+_Replace `[...]` with your desired features._
+
+**Available Tags:**
+For a full list of optional dependencies, see:
+https://pypi.org/project/lerobot/
+
+### Weights & Biases
+
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
+
```bash
wandb login
```
@@ -151,37 +197,37 @@ wandb login
### Visualize datasets
-Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
+Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
+
```bash
-python -m lerobot.scripts.visualize_dataset \
+lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
+
```bash
-python -m lerobot.scripts.visualize_dataset \
+lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
--episode-index 0
```
-
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
-
-Our script can also visualize datasets stored on a distant server. See `python -m lerobot.scripts.visualize_dataset --help` for more instructions.
+Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
### The `LeRobotDataset` format
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
-A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
+A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`.
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
@@ -200,213 +246,98 @@ dataset attributes:
│ ├ timestamp (float32): timestamp in the episode
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
│ └ index (int64): general index in the whole dataset
- ├ episode_data_index: contains 2 tensors with the start and end indices of each episode
- │ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0
- │ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,)
- ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
- │ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
- │ ...
- ├ info: a dictionary of metadata on the dataset
- │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
- │ ├ fps (float): frame per second the dataset is recorded/synchronized to
- │ ├ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
- │ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
- ├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
- └ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
+ ├ meta: a LeRobotDatasetMetadata object containing:
+ │ ├ info: a dictionary of metadata on the dataset
+ │ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
+ │ │ ├ fps (int): frame per second the dataset is recorded/synchronized to
+ │ │ ├ features (dict): all features contained in the dataset with their shapes and types
+ │ │ ├ total_episodes (int): total number of episodes in the dataset
+ │ │ ├ total_frames (int): total number of frames in the dataset
+ │ │ ├ robot_type (str): robot type used for recording
+ │ │ ├ data_path (str): formattable string for the parquet files
+ │ │ └ video_path (str): formattable string for the video files (if using videos)
+ │ ├ episodes: a DataFrame containing episode metadata with columns:
+ │ │ ├ episode_index (int): index of the episode
+ │ │ ├ tasks (list): list of tasks for this episode
+ │ │ ├ length (int): number of frames in this episode
+ │ │ ├ dataset_from_index (int): start index of this episode in the dataset
+ │ │ └ dataset_to_index (int): end index of this episode in the dataset
+ │ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
+ │ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
+ │ │ └ ...
+ │ └ tasks: a DataFrame containing task information with task names as index and task_index as values
+ ├ root (Path): local directory where the dataset is stored
+ ├ image_transforms (Callable): optional image transformations to apply to visual modalities
+ └ delta_timestamps (dict): optional delta timestamps for temporal queries
```
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
+
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
-### Evaluate a pretrained policy
-
-Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
-
-We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
-```bash
-python -m lerobot.scripts.eval \
- --policy.path=lerobot/diffusion_pusht \
- --env.type=pusht \
- --eval.batch_size=10 \
- --eval.n_episodes=10 \
- --policy.use_amp=false \
- --policy.device=cuda
-```
-
-Note: After training your own policy, you can re-evaluate the checkpoints with:
-
-```bash
-python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model
-```
-
-See `python -m lerobot.scripts.eval --help` for more instructions.
-
-### Train your own policy
-
-Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
-
-To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
-
-A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
-
-
-
-Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
-
#### Reproduce state-of-the-art (SOTA)
We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances.
You can reproduce their training by loading the config from their run. Simply running:
+
```bash
-python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
+lerobot-train --config_path=lerobot/diffusion_pusht
```
+
reproduces SOTA results for Diffusion Policy on the PushT task.
## Contribute
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
-
-
-
### Add a pretrained policy
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
+
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
To upload these to the hub, run the following:
+
```bash
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
-See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
+See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
+### Acknowledgment
-### Improve your code with profiling
-
-An example of a code snippet to profile the evaluation of a policy:
-```python
-from torch.profiler import profile, record_function, ProfilerActivity
-
-def trace_handler(prof):
- prof.export_chrome_trace(f"tmp/trace_schedule_{prof.step_num}.json")
-
-with profile(
- activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
- schedule=torch.profiler.schedule(
- wait=2,
- warmup=2,
- active=3,
- ),
- on_trace_ready=trace_handler
-) as prof:
- with record_function("eval_policy"):
- for i in range(num_episodes):
- prof.step()
- # insert code to profile, potentially whole body of eval_policy function
-```
+- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
+- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
+- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
+- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
## Citation
If you want, you can cite this work with:
+
```bibtex
@misc{cadene2024lerobot,
- author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
+ author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
}
```
-Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below:
-- [SmolVLA](https://arxiv.org/abs/2506.01844)
-```bibtex
-@article{shukor2025smolvla,
- title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics},
- author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi},
- journal={arXiv preprint arXiv:2506.01844},
- year={2025}
-}
-```
-
-- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu)
-```bibtex
-@article{chi2024diffusionpolicy,
- author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
- title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
- journal = {The International Journal of Robotics Research},
- year = {2024},
-}
-```
-- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha)
-```bibtex
-@article{zhao2023learning,
- title={Learning fine-grained bimanual manipulation with low-cost hardware},
- author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
- journal={arXiv preprint arXiv:2304.13705},
- year={2023}
-}
-```
-
-- [TDMPC](https://www.nicklashansen.com/td-mpc/)
-
-```bibtex
-@inproceedings{Hansen2022tdmpc,
- title={Temporal Difference Learning for Model Predictive Control},
- author={Nicklas Hansen and Xiaolong Wang and Hao Su},
- booktitle={ICML},
- year={2022}
-}
-```
-
-- [VQ-BeT](https://sjlee.cc/vq-bet/)
-```bibtex
-@article{lee2024behavior,
- title={Behavior generation with latent actions},
- author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
- journal={arXiv preprint arXiv:2403.03181},
- year={2024}
-}
-```
-
-
-- [HIL-SERL](https://hil-serl.github.io/)
-```bibtex
-@Article{luo2024hilserl,
-title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
-author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
-year={2024},
-eprint={2410.21845},
-archivePrefix={arXiv},
-primaryClass={cs.RO}
-}
-```
## Star History
[](https://star-history.com/#huggingface/lerobot&Timeline)
+
+```
+
+```
diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md
index daa3e1f48..490a4b495 100644
--- a/benchmarks/video/README.md
+++ b/benchmarks/video/README.md
@@ -1,28 +1,32 @@
# Video benchmark
-
## Questions
+
What is the optimal trade-off between:
+
- maximizing loading time with random access,
- minimizing memory space on disk,
- maximizing success rate of policies,
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
How to encode videos?
+
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
How to decode videos?
+
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
-
## Variables
+
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
+
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
@@ -34,8 +38,9 @@ Note: The datasets used for this benchmark need to be image datasets, not video
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
### Encoding parameters
+
| parameter | values |
-|-------------|--------------------------------------------------------------|
+| ----------- | ------------------------------------------------------------ |
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
| **pix_fmt** | `yuv444p`, `yuv420p` |
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
@@ -44,19 +49,23 @@ We might revisit this benchmark and find better settings if we train our policie
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
+
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
### Decoding parameters
+
**Decoder**
We tested two video decoding backends from torchvision:
+
- `pyav`
- `video_reader` (requires to build torchvision from source)
**Requested timestamps**
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
+
- `1_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
@@ -64,12 +73,13 @@ This of course is affected by the `-g` parameter during encoding, which specifie
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
+
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
-
## Metrics
+
**Data compression ratio (lower is better)**
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
@@ -87,18 +97,18 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
+
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
-
-
## How the benchmark works
+
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
@@ -110,15 +120,18 @@ Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv ta
These are then all concatenated to a single table ready for analysis.
## Caveats
+
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
Additional encoding parameters exist that are not included in this benchmark. In particular:
+
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
+
- `torchaudio`
- `ffmpegio`
- `decord`
@@ -127,16 +140,17 @@ Similarly on the decoding side, other decoders exist but are not implemented in
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
-
## Install
+
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
-
## Adding a video decoder
+
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
You can easily add a new decoder to benchmark by adding it to this function in the script:
+
```diff
def decode_video_frames(
video_path: str,
@@ -156,9 +170,10 @@ def decode_video_frames(
raise NotImplementedError(backend)
```
-
## Example
+
For a quick run, you can try these parameters:
+
```bash
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
@@ -176,11 +191,12 @@ python benchmark/video/run_video_benchmark.py \
--save-frames 0
```
-
## Results
### Reproduce
+
We ran the benchmark with the following parameters:
+
```bash
# h264 and h265 encodings
python benchmark/video/run_video_benchmark.py \
@@ -221,9 +237,10 @@ python benchmark/video/run_video_benchmark.py \
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
-
### Parameters selected for LeRobotDataset
+
Considering these results, we chose what we think is the best set of encoding parameter:
+
- vcodec: `libsvtav1`
- pix-fmt: `yuv420p`
- g: `2`
@@ -236,7 +253,7 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
-|------------------------------------|------------|---------|-----------|-----------|-----------|
+| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
@@ -245,7 +262,7 @@ These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
-|------------------------------------|---------|---------|----------|---------|-----------|
+| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
@@ -254,7 +271,7 @@ These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
-|------------------------------------|----------|----------|--------------|----------|-----------|--------------|
+| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
diff --git a/src/lerobot/utils/benchmark.py b/benchmarks/video/benchmark.py
similarity index 99%
rename from src/lerobot/utils/benchmark.py
rename to benchmarks/video/benchmark.py
index 4b08e6f6d..d9e5e62bb 100644
--- a/src/lerobot/utils/benchmark.py
+++ b/benchmarks/video/benchmark.py
@@ -46,11 +46,13 @@ class TimeBenchmark(ContextDecorator):
benchmark = TimeBenchmark()
+
def context_manager_example():
with benchmark:
time.sleep(0.01)
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
+
threads = []
for _ in range(3):
t1 = threading.Thread(target=context_manager_example)
diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py
index bababf636..9f34b2273 100644
--- a/benchmarks/video/run_video_benchmark.py
+++ b/benchmarks/video/run_video_benchmark.py
@@ -35,12 +35,13 @@ import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
+from benchmarks.video.benchmark import TimeBenchmark
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames_torchvision,
encode_video_frames,
)
-from lerobot.utils.benchmark import TimeBenchmark
+from lerobot.utils.constants import OBS_IMAGE
BASE_ENCODING = OrderedDict(
[
@@ -108,7 +109,8 @@ def save_decoded_frames(
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
- ep_num_images = dataset.episode_data_index["to"][0].item()
+ episode_index = 0
+ ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
return
@@ -116,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
- img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
+ img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
@@ -265,7 +267,8 @@ def benchmark_encoding_decoding(
overwrite=True,
)
- ep_num_images = dataset.episode_data_index["to"][0].item()
+ episode_index = 0
+ ep_num_images = dataset.meta.episodes["length"][episode_index]
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal
new file mode 100644
index 000000000..2616cd06c
--- /dev/null
+++ b/docker/Dockerfile.internal
@@ -0,0 +1,93 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This Dockerfile is designed for HuggingFace internal CI environments
+# that require GPU access. It starts from an NVIDIA CUDA base image.
+
+# docker build -f docker/Dockerfile.internal -t lerobot-internal .
+
+# Configure the base image for CI with GPU access
+# TODO(Steven): Bump these versions
+ARG CUDA_VERSION=12.4.1
+ARG OS_VERSION=22.04
+FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
+
+# Define Python version argument
+ARG PYTHON_VERSION=3.10
+
+# Configure environment variables
+ENV DEBIAN_FRONTEND=noninteractive \
+ MUJOCO_GL=egl \
+ PATH=/lerobot/.venv/bin:$PATH \
+ CUDA_VISIBLE_DEVICES=0 \
+ TEST_TYPE=single_gpu \
+ DEVICE=cuda
+
+# Install Python, system dependencies, and uv (as root)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ software-properties-common build-essential git curl \
+ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
+ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
+ cmake pkg-config ninja-build \
+ && add-apt-repository -y ppa:deadsnakes/ppa \
+ && apt-get update \
+ && apt-get install -y --no-install-recommends \
+ python${PYTHON_VERSION} \
+ python${PYTHON_VERSION}-venv \
+ python${PYTHON_VERSION}-dev \
+ && curl -LsSf https://astral.sh/uv/install.sh | sh \
+ && mv /root/.local/bin/uv /usr/local/bin/uv \
+ && useradd --create-home --shell /bin/bash user_lerobot \
+ && usermod -aG sudo user_lerobot \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Create application directory and set permissions
+WORKDIR /lerobot
+RUN chown -R user_lerobot:user_lerobot /lerobot
+
+# Switch to the non-root user
+USER user_lerobot
+
+# Environment variables for the testing
+ENV HOME=/home/user_lerobot \
+ HF_HOME=/home/user_lerobot/.cache/huggingface \
+ HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
+ TORCH_HOME=/home/user_lerobot/.cache/torch \
+ TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
+
+# Create the virtual environment
+# We use a virtual environment inside the container—even though the container itself \
+# provides isolation—to ensure compatibility with the cluster and to prevent \
+# issues with MuJoCo and OpenGL drivers.
+RUN uv venv --python python${PYTHON_VERSION}
+
+# Install Python dependencies for caching
+COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
+COPY --chown=user_lerobot:user_lerobot src/ src/
+
+ARG UNBOUND_DEPS=false
+
+RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
+ sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
+ echo "Dependencies unbound:" && cat pyproject.toml; \
+ fi
+
+RUN uv pip install --no-cache ".[all]"
+
+# Copy the rest of the application source code
+# Make sure to have the git-LFS files for testing
+COPY --chown=user_lerobot:user_lerobot . .
+
+# Set the default command
+CMD ["/bin/bash"]
diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user
new file mode 100644
index 000000000..c1b284453
--- /dev/null
+++ b/docker/Dockerfile.user
@@ -0,0 +1,79 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This Dockerfile is designed for a lerobot user who wants to
+# experiment with the project. It starts from an Python Slim base image.
+
+# docker build -f docker/Dockerfile.user -t lerobot-user .
+# docker run -it --rm lerobot-user
+
+# Configure the base image
+ARG PYTHON_VERSION=3.10
+FROM python:${PYTHON_VERSION}-slim
+
+# Configure environment variables
+ENV DEBIAN_FRONTEND=noninteractive \
+ MUJOCO_GL=egl \
+ PATH=/lerobot/.venv/bin:$PATH
+
+# Install system dependencies and uv (as root)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \
+ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
+ cmake pkg-config ninja-build \
+ && curl -LsSf https://astral.sh/uv/install.sh | sh \
+ && mv /root/.local/bin/uv /usr/local/bin/uv \
+ && useradd --create-home --shell /bin/bash user_lerobot \
+ && usermod -aG sudo user_lerobot \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Create application directory and set permissions
+WORKDIR /lerobot
+RUN chown -R user_lerobot:user_lerobot /lerobot
+
+# Switch to the non-root user
+USER user_lerobot
+
+# Environment variables for the testing
+ENV HOME=/home/user_lerobot \
+ HF_HOME=/home/user_lerobot/.cache/huggingface \
+ HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
+ TORCH_HOME=/home/user_lerobot/.cache/torch \
+ TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton
+
+# Create the virtual environment
+# We use a virtual environment inside the container—even though the container itself \
+# provides isolation—to closely resemble local development and allow users to \
+# run other Python projects in the same container without dependency conflicts.
+RUN uv venv
+
+# Install Python dependencies for caching
+COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./
+COPY --chown=user_lerobot:user_lerobot src/ src/
+
+ARG UNBOUND_DEPS=false
+
+RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
+ sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
+ echo "Dependencies unbound:" && cat pyproject.toml; \
+ fi
+
+RUN uv pip install --no-cache ".[all]"
+
+# Copy the rest of the application code
+# Make sure to have the git-LFS files for testing
+COPY --chown=user_lerobot:user_lerobot . .
+
+# Set the default command
+CMD ["/bin/bash"]
diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile
deleted file mode 100644
index 85c31ac1a..000000000
--- a/docker/lerobot-cpu/Dockerfile
+++ /dev/null
@@ -1,29 +0,0 @@
-# Configure image
-ARG PYTHON_VERSION=3.10
-FROM python:${PYTHON_VERSION}-slim
-
-# Configure environment variables
-ARG PYTHON_VERSION
-ENV DEBIAN_FRONTEND=noninteractive
-ENV MUJOCO_GL="egl"
-ENV PATH="/opt/venv/bin:$PATH"
-
-# Install dependencies and set up Python in a single layer
-RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential cmake git \
- libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
- speech-dispatcher libgeos-dev \
- && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
- && python -m venv /opt/venv \
- && apt-get clean && rm -rf /var/lib/apt/lists/* \
- && echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Clone repository and install LeRobot in a single layer
-COPY . /lerobot
-WORKDIR /lerobot
-RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
- && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, smolvla]" \
- --extra-index-url https://download.pytorch.org/whl/cpu
-
-# Execute in bash shell rather than python
-CMD ["/bin/bash"]
diff --git a/docker/lerobot-gpu-dev/Dockerfile b/docker/lerobot-gpu-dev/Dockerfile
deleted file mode 100644
index 4d25b2550..000000000
--- a/docker/lerobot-gpu-dev/Dockerfile
+++ /dev/null
@@ -1,68 +0,0 @@
-FROM nvidia/cuda:12.2.2-devel-ubuntu22.04
-
-# Configure image
-ARG PYTHON_VERSION=3.10
-ARG DEBIAN_FRONTEND=noninteractive
-
-# Install apt dependencies
-RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential cmake \
- git git-lfs openssh-client \
- nano vim less util-linux tree \
- htop atop nvtop \
- sed gawk grep curl wget zip unzip \
- tcpdump sysstat screen tmux \
- libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
- speech-dispatcher portaudio19-dev libgeos-dev \
- python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
- && apt-get clean && rm -rf /var/lib/apt/lists/*
-
-# Install ffmpeg build dependencies. See:
-# https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu
-# TODO(aliberts): create image to build dependencies from source instead
-RUN apt-get update && apt-get install -y --no-install-recommends \
- autoconf automake yasm \
- libass-dev \
- libfreetype6-dev \
- libgnutls28-dev \
- libunistring-dev \
- libmp3lame-dev \
- libtool \
- libvorbis-dev \
- meson \
- ninja-build \
- pkg-config \
- texinfo \
- yasm \
- zlib1g-dev \
- nasm \
- libx264-dev \
- libx265-dev libnuma-dev \
- libvpx-dev \
- libfdk-aac-dev \
- libopus-dev \
- libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \
- libdav1d-dev
-
-# Install gh cli tool
-RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
- && mkdir -p -m 755 /etc/apt/keyrings \
- && wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
- && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
- && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
- && apt update \
- && apt install gh -y \
- && apt clean && rm -rf /var/lib/apt/lists/*
-
-# Setup `python`
-RUN ln -s /usr/bin/python3 /usr/bin/python
-
-# Install poetry
-RUN curl -sSL https://install.python-poetry.org | python -
-ENV PATH="/root/.local/bin:$PATH"
-RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
-RUN poetry config virtualenvs.create false
-RUN poetry config virtualenvs.in-project true
-
-# Set EGL as the rendering backend for MuJoCo
-ENV MUJOCO_GL="egl"
diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile
deleted file mode 100644
index 746ea29b7..000000000
--- a/docker/lerobot-gpu/Dockerfile
+++ /dev/null
@@ -1,24 +0,0 @@
-FROM nvidia/cuda:12.4.1-base-ubuntu22.04
-
-# Configure environment variables
-ARG PYTHON_VERSION=3.10
-ENV DEBIAN_FRONTEND=noninteractive
-ENV MUJOCO_GL="egl"
-ENV PATH="/opt/venv/bin:$PATH"
-
-# Install dependencies and set up Python in a single layer
-RUN apt-get update && apt-get install -y --no-install-recommends \
- build-essential cmake git \
- libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
- speech-dispatcher libgeos-dev \
- python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
- && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \
- && python -m venv /opt/venv \
- && apt-get clean && rm -rf /var/lib/apt/lists/* \
- && echo "source /opt/venv/bin/activate" >> /root/.bashrc
-
-# Clone repository and install LeRobot in a single layer
-COPY . /lerobot
-WORKDIR /lerobot
-RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \
- && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel, smolvla]"
diff --git a/docs-requirements.txt b/docs-requirements.txt
new file mode 100644
index 000000000..e286ad2bb
--- /dev/null
+++ b/docs-requirements.txt
@@ -0,0 +1,3 @@
+# docs-requirements.txt
+hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main
+watchdog>=6.0.0
diff --git a/docs/README.md b/docs/README.md
index 275fee46b..476eb8dce 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -20,12 +20,13 @@ To generate the documentation, you first have to build it. Several packages are
you can install them with the following command, at the root of the code repository:
```bash
-pip install -e ".[docs]"
+pip install -e . -r docs-requirements.txt
```
You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download)
---
+
**NOTE**
You only need to generate the documentation to inspect it locally (if you're planning changes and want to
@@ -63,6 +64,7 @@ doc-builder preview lerobot docs/source/
The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
---
+
**NOTE**
The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
@@ -89,6 +91,7 @@ Sections that were moved:
[ Section A ]
```
+
and of course, if you moved it to another file, then:
```
@@ -119,7 +122,6 @@ and objects like True, None or any strings should usually be put in `code`.
Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
-
````
```
# first line of code
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index ea80e8257..36eaea165 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -17,11 +17,35 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
+ - local: async
+ title: Use Async Inference
title: "Tutorials"
+- sections:
+ - local: lerobot-dataset-v3
+ title: Using LeRobotDataset
+ - local: porting_datasets_v3
+ title: Porting Large Datasets
+ title: "Datasets"
- sections:
- local: smolvla
- title: Finetune SmolVLA
+ title: SmolVLA
+ - local: pi0
+ title: π₀ (Pi0)
+ - local: pi05
+ title: π₀.₅ (Pi05)
+ - local: libero
+ title: Using Libero
title: "Policies"
+- sections:
+ - local: introduction_processors
+ title: Introduction to Robot Processors
+ - local: debug_processor_pipeline
+ title: Debug your processor pipeline
+ - local: implement_your_own_processor
+ title: Implement your own processor
+ - local: processors_robots_teleop
+ title: Processors for Robots and Teleoperators
+ title: "Robot Processors"
- sections:
- local: so101
title: SO-101
@@ -31,10 +55,20 @@
title: Koch v1.1
- local: lekiwi
title: LeKiwi
+ - local: hope_jr
+ title: Hope Jr
+ - local: reachy2
+ title: Reachy 2
title: "Robots"
+- sections:
+ - local: phone_teleop
+ title: Phone
+ title: "Teleoperators"
- sections:
- local: notebooks
title: Notebooks
+ - local: feetech
+ title: Updating Feetech Firmware
title: "Resources"
- sections:
- local: contributing
diff --git a/docs/source/async.mdx b/docs/source/async.mdx
new file mode 100644
index 000000000..be10f8baf
--- /dev/null
+++ b/docs/source/async.mdx
@@ -0,0 +1,312 @@
+# Asynchronous Inference
+
+With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
+In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot.
+**Try async inference with all the policies** supported by LeRobot!
+
+**What you'll learn:**
+
+1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference.
+2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network.
+3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
+
+If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
+
+In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
+This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
+
+---
+
+## Getting started with async inference
+
+You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment.
+
+First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference.
+
+```shell
+pip install -e ".[async]"
+```
+
+Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to.
+You can spin up a policy server running:
+
+```shell
+python -m lerobot.async_inference.policy_server \
+ --host=127.0.0.1 \
+ --port=8080
+```
+
+This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
+
+```shell
+python -m lerobot.async_inference.robot_client \
+ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
+ --robot.type=so100_follower \ # ROBOT: your robot type
+ --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
+ --robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
+ --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
+ --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
+ --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
+ --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
+ --policy_device=mps \ # POLICY: the device to run the policy on, on the server
+ --actions_per_chunk=50 \ # POLICY: the number of actions to output at once
+ --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
+ --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+ --debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
+```
+
+In summary, you need to specify instructions for:
+
+- `SERVER`: the address and port of the policy server
+- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
+- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value).
+- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
+
+Importantly,
+
+- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
+- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
+- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
+
+## Done! You should see your robot moving around by now 😉
+
+## Async vs. synchronous inference
+
+Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk.
+In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions.
+With robotics models increasing in size, this problem risks becoming only more severe.
+
+
+
+
+
+ Synchronous inference makes the robot idle while the policy is
+ computing the next chunk of actions.
+
+
+To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
+Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness.
+Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
+
+
+
+
+
+ Asynchronous inference results in no idleness because the next chunk is
+ computed before the current chunk is exhausted.
+
+
+---
+
+## Start the Policy Server
+
+Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client.
+Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
+As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
+
+
+
+```bash
+python -m lerobot.async_inference.policy_server \
+ --host=127.0.0.1 \
+ --port=8080
+```
+
+
+
+
+```python
+from lerobot.async_inference.configs import PolicyServerConfig
+from lerobot.async_inference.policy_server import serve
+
+config = PolicyServerConfig(
+ host="localhost",
+ port=8080,
+)
+serve(config)
+```
+
+
+
+
+
+This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake.
+
+---
+
+## Launch the Robot Client
+
+`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`.
+The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller).
+
+
+
+```bash
+python -m lerobot.async_inference.robot_client \
+ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
+ --robot.type=so100_follower \ # ROBOT: your robot type
+ --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
+ --robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
+ --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
+ --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
+ --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
+ --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
+ --policy_device=mps \ # POLICY: the device to run the policy on, on the server
+ --actions_per_chunk=50 \ # POLICY: the number of actions to output at once
+ --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
+ --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+ --debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
+```
+
+
+
+
+```python
+import threading
+from lerobot.robots.so100_follower import SO100FollowerConfig
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.async_inference.configs import RobotClientConfig
+from lerobot.async_inference.robot_client import RobotClient
+from lerobot.async_inference.helpers import visualize_action_queue_size
+
+# 1. Create the robot instance
+"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
+# these cameras must match the ones expected by the policy
+# check the config.json on the Hub for the policy you are using
+camera_cfg = {
+ "top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
+ "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
+}
+
+robot_cfg = SO100FollowerConfig(
+ port="/dev/tty.usbmodem585A0076841",
+ id="follower_so100",
+ cameras=camera_cfg
+)
+
+# 3. Create client configuration
+client_cfg = RobotClientConfig(
+ robot=robot_cfg,
+ server_address="localhost:8080",
+ policy_device="mps",
+ policy_type="smolvla",
+ pretrained_name_or_path="fracapuano/smolvla_async",
+ chunk_size_threshold=0.5,
+ actions_per_chunk=50, # make sure this is less than the max actions of the policy
+)
+
+# 4. Create and start client
+client = RobotClient(client_cfg)
+
+# 5. Specify the task
+task = "Don't do anything, stay still"
+
+if client.start():
+ # Start action receiver thread
+ action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
+ action_receiver_thread.start()
+
+ try:
+ # Run the control loop
+ client.control_loop(task)
+ except KeyboardInterrupt:
+ client.stop()
+ action_receiver_thread.join()
+ # (Optionally) plot the action queue size
+ visualize_action_queue_size(client.action_queue_size)
+```
+
+
+
+
+
+The following two parameters are key in every setup:
+
+
+
+
+
Hyperparameter
+
Default
+
What it does
+
+
+
+
+
+ actions_per_chunk
+
+
50
+
+ How many actions the policy outputs at once. Typical values: 10-50.
+
+
+
+
+ chunk_size_threshold
+
+
0.7
+
+ When the queue is ≤ 50% full, the client sends a fresh observation.
+ Value in [0, 1].
+
+
+
+
+
+
+ Different values of `actions_per_chunk` and `chunk_size_threshold` do result
+ in different behaviours.
+
+
+On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
+However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans.
+
+On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced.
+This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted.
+
+We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
+
+### Tuning async inference for your setup
+
+1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
+2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
+3. **Adjust `chunk_size_threshold`**.
+ - Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
+ - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
+
+
+
+
+
+
+ The action queue size is plotted at runtime when the
+ `--debug-visualize-queue-size` flag is passed, for various levels of
+ `chunk_size_threshold` (`g` in the SmolVLA paper).
+
+
+
+---
+
+## Conclusion
+
+Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
+
+**Key Takeaways:**
+
+- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
+- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
+- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
+- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
+- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
+
+Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
+If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx
index 555239170..3366c8ab9 100644
--- a/docs/source/backwardcomp.mdx
+++ b/docs/source/backwardcomp.mdx
@@ -1,26 +1,83 @@
# Backward compatibility
+## Policy Normalization Migration (PR #1452)
+
+**Breaking Change**: LeRobot policies no longer have built-in normalization layers embedded in their weights. Normalization is now handled by external `PolicyProcessorPipeline` components.
+
+### What changed?
+
+| | Before PR #1452 | After PR #1452 |
+| -------------------------- | ------------------------------------------------ | ------------------------------------------------------------ |
+| **Normalization Location** | Embedded in model weights (`normalize_inputs.*`) | External `PolicyProcessorPipeline` components |
+| **Model State Dict** | Contains normalization statistics | **Clean weights only** - no normalization parameters |
+| **Usage** | `policy(batch)` handles everything | `preprocessor(batch)` → `policy(...)` → `postprocessor(...)` |
+
+### Impact on existing models
+
+- Models trained **before** PR #1452 have normalization embedded in their weights
+- These models need migration to work with the new `PolicyProcessorPipeline` system
+- The migration extracts normalization statistics and creates separate processor pipelines
+
+### Migrating old models
+
+Use the migration script to convert models with embedded normalization:
+
+```shell
+python src/lerobot/processor/migrate_policy_normalization.py \
+ --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
+ --push-to-hub \
+ --branch migrated
+```
+
+The script:
+
+1. **Extracts** normalization statistics from model weights
+2. **Creates** external preprocessor and postprocessor pipelines
+3. **Removes** normalization layers from model weights
+4. **Saves** clean model + processor pipelines
+5. **Pushes** to Hub with automatic PR creation
+
+### Using migrated models
+
+```python
+# New usage pattern (after migration)
+from lerobot.policies.factory import make_policy, make_pre_post_processors
+
+# Load model and processors separately
+policy = make_policy(config, ds_meta=dataset.meta)
+preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=config,
+ dataset_stats=dataset.meta.stats
+)
+
+# Process data through pipeline
+processed_batch = preprocessor(raw_batch)
+action = policy.select_action(processed_batch)
+final_action = postprocessor(action)
+```
+
## Hardware API redesign
PR [#777](https://github.com/huggingface/lerobot/pull/777) improves the LeRobot calibration but is **not backward-compatible**. Below is a overview of what changed and how you can continue to work with datasets created before this pull request.
### What changed?
-| | Before PR #777 | After PR #777 |
-| --------------------------------- | ------------------------------------------------- | --------------------------------------------------------------------------- |
-| **Joint range** | Degrees `-180...180°` | **Normalised range** Joints: `–100...100` Gripper: `0...100` |
-| **Zero position (SO100 / SO101)** | Arm fully extended horizontally | **In middle of the range for each joint** |
-| **Boundary handling** | Software safeguards to detect ±180 ° wrap-arounds | No wrap-around logic needed due to mid-range zero |
+| | Before PR #777 | After PR #777 |
+| --------------------------------- | ------------------------------------------------- | ------------------------------------------------------------ |
+| **Joint range** | Degrees `-180...180°` | **Normalised range** Joints: `–100...100` Gripper: `0...100` |
+| **Zero position (SO100 / SO101)** | Arm fully extended horizontally | **In middle of the range for each joint** |
+| **Boundary handling** | Software safeguards to detect ±180 ° wrap-arounds | No wrap-around logic needed due to mid-range zero |
---
### Impact on existing datasets
-* Recorded trajectories created **before** PR #777 will replay incorrectly if loaded directly:
- * Joint angles are offset and incorrectly normalized.
-* Any models directly finetuned or trained on the old data will need their inputs and outputs converted.
+- Recorded trajectories created **before** PR #777 will replay incorrectly if loaded directly:
+ - Joint angles are offset and incorrectly normalized.
+- Any models directly finetuned or trained on the old data will need their inputs and outputs converted.
### Using datasets made with the previous calibration system
+
We provide a migration example script for replaying an episode recorded with the previous calibration here: `examples/backward_compatibility/replay.py`.
Below we take you through the modifications that are done in the example script to make the previous calibration datasets work.
@@ -33,20 +90,31 @@ Below we take you through the modifications that are done in the example script
Let's break this down.
New codebase uses `.pos` suffix for the position observations and we have removed `main_` prefix:
+
+
```python
key = f"{name.removeprefix('main_')}.pos"
```
+
For `"shoulder_lift"` (id = 2), the 0 position is changed by -90 degrees and the direction is reversed compared to old calibration/code.
+
+
```python
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
```
+
+
For `"elbow_flex"` (id = 3), the 0 position is changed by -90 degrees compared to old calibration/code.
+
+
```python
action["elbow_flex.pos"] -= 90
```
+
To use degrees normalization we then set the `--robot.use_degrees` option to `true`.
+
```diff
python examples/backward_compatibility/replay.py \
--robot.type=so101_follower \
@@ -63,6 +131,7 @@ Policies output actions in the same format as the datasets (`torch.Tensors`). Th
To find these transformations, we recommend to first try and and replay an episode of the dataset your policy was trained on using the section above.
Then, add these same transformations on your inference script (shown here in the `record.py` script):
+
```diff
action_values = predict_action(
observation_frame,
diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx
index 313d5a7cd..5c35be0ba 100644
--- a/docs/source/cameras.mdx
+++ b/docs/source/cameras.mdx
@@ -7,11 +7,13 @@ LeRobot offers multiple options for video capture, including phone cameras, buil
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
To find the camera indices of the cameras plugged into your system, run the following script:
+
```bash
-python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras
+lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
```
The output will look something like this if you have two cameras connected:
+
```
--- Detected Cameras ---
Camera #0:
@@ -31,7 +33,6 @@ Camera #0:
> [!WARNING]
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
-
## Use Cameras
Below are two examples, demonstrating how to work with the API.
@@ -39,10 +40,10 @@ Below are two examples, demonstrating how to work with the API.
- **Asynchronous frame capture** using an OpenCV-based camera
- **Color and depth capture** using an Intel RealSense camera
-
+
```python
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
@@ -70,10 +71,12 @@ try:
finally:
camera.disconnect()
```
+
+
```python
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
@@ -103,15 +106,18 @@ try:
finally:
camera.disconnect()
```
+
+
-
## Use your phone
+
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
+
- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later.
- Sign in both devices with the same Apple ID.
- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection.
@@ -125,40 +131,67 @@ Your iPhone should be detected automatically when running the camera setup scrip
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
-1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
+1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
+
+
```python
sudo apt install v4l2loopback-dkms v4l-utils
```
-2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android.
-3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
+
+
+2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
+3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
+
+
```python
flatpak install flathub com.obsproject.Studio
```
-4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with:
+
+
+4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
+
+
```python
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
```
-5. *Start OBS Studio*. Launch with:
+
+
+5. _Start OBS Studio_. Launch with:
+
+
```python
flatpak run com.obsproject.Studio
```
-6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
-7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
-8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
-9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices:
+
+
+6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
+7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
+8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
+9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
+
+
```python
v4l2-ctl --list-devices
```
+
+
You should see an entry like:
+
```
VirtualCam (platform:v4l2loopback-000):
/dev/video1
```
-10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
+
+10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
+
+
```python
v4l2-ctl -d /dev/video1 --get-fmt-video
```
+
+
You should see an entry like:
+
```
>>> Format Video Capture:
>>> Width/Height : 640/480
diff --git a/docs/source/debug_processor_pipeline.mdx b/docs/source/debug_processor_pipeline.mdx
new file mode 100644
index 000000000..4826c947e
--- /dev/null
+++ b/docs/source/debug_processor_pipeline.mdx
@@ -0,0 +1,299 @@
+# Debug Your Processor Pipeline
+
+Processor pipelines can be complex, especially when chaining multiple transformation steps.
+Unlike simple function calls, pipelines lack natural observability, you can't easily see what happens
+between each step or where things go wrong.
+This guide provides debugging tools and techniques specifically designed to address these challenges
+and help you understand data flow through your pipelines.
+
+We'll explore three complementary debugging approaches: **hooks** for runtime monitoring, **step-through debugging** for detailed inspection, and **feature validation** for catching structural mismatches. Each serves a different purpose and together they provide complete visibility into your pipeline's behavior.
+
+## Understanding Hooks
+
+Hooks are functions that get called at specific points during pipeline execution.
+They provide a way to inspect, monitor, or modify data without changing your pipeline code.
+Think of them as "event listeners" for your pipeline.
+
+### What is a Hook?
+
+A hook is a callback function that gets automatically invoked at specific moments during pipeline execution.
+The concept comes from event-driven programming, imagine you could "hook into" the pipeline's execution flow to observe or react to what's happening.
+
+Think of hooks like inserting checkpoints into your pipeline. Every time the pipeline reaches one of these checkpoints, it pauses briefly to call your hook function, giving you a chance to inspect the current state, log information, and validate data.
+
+A hook is simply a function that accepts two parameters:
+
+- `step_idx: int` - The index of the current processing step (0, 1, 2, etc.)
+- `transition: EnvTransition` - The data transition at that point in the pipeline
+
+The beauty of hooks is their non-invasive nature: you can add monitoring, validation, or debugging logic without changing a single line of your pipeline code. The pipeline remains clean and focused on its core logic, while hooks handle the cross-cutting concerns like logging, monitoring, and debugging.
+
+### Before vs After Hooks
+
+The pipeline supports two types of hooks:
+
+- **Before hooks** (`register_before_step_hook`) - Called before each step executes
+- **After hooks** (`register_after_step_hook`) - Called after each step completes
+
+```python
+def before_hook(step_idx: int, transition: EnvTransition):
+ """Called before step processes the transition."""
+ print(f"About to execute step {step_idx}")
+ # Useful for: logging, validation, setup
+
+def after_hook(step_idx: int, transition: EnvTransition):
+ """Called after step has processed the transition."""
+ print(f"Completed step {step_idx}")
+ # Useful for: monitoring results, cleanup, debugging
+
+processor.register_before_step_hook(before_hook)
+processor.register_after_step_hook(after_hook)
+```
+
+### Implementing a NaN Detection Hook
+
+Here's a practical example of a hook that detects NaN values:
+
+```python
+def check_nans(step_idx: int, transition: EnvTransition):
+ """Check for NaN values in observations."""
+ obs = transition.get(TransitionKey.OBSERVATION)
+ if obs:
+ for key, value in obs.items():
+ if isinstance(value, torch.Tensor) and torch.isnan(value).any():
+ print(f"NaN detected in {key} at step {step_idx}")
+
+# Register the hook to run after each step
+processor.register_after_step_hook(check_nans)
+
+# Process your data - the hook will be called automatically
+output = processor(input_data)
+
+# Remove the hook when done debugging
+processor.unregister_after_step_hook(check_nans)
+```
+
+### How Hooks Work Internally
+
+Understanding the internal mechanism helps you use hooks more effectively. The pipeline maintains two separate lists: one for before-step hooks and another for after-step hooks. When you register a hook, it's simply appended to the appropriate list.
+
+During execution, the pipeline follows a strict sequence: for each processing step, it first calls all before-hooks in registration order, then executes the actual step transformation, and finally calls all after-hooks in registration order. This creates a predictable, sandwich-like structure around each step.
+
+The key insight is that hooks don't change the core pipeline logic—they're purely additive. The pipeline's `_forward` method orchestrates this dance between hooks and processing steps, ensuring that your debugging or monitoring code runs at exactly the right moments without interfering with the main data flow.
+
+Here's a simplified view of how the pipeline executes hooks:
+
+```python
+class DataProcessorPipeline:
+ def __init__(self):
+ self.steps = [...]
+ self.before_step_hooks = [] # List of before hooks
+ self.after_step_hooks = [] # List of after hooks
+
+ def _forward(self, transition):
+ """Internal method that processes the transition through all steps."""
+ for step_idx, processor_step in enumerate(self.steps):
+ # 1. Call all BEFORE hooks
+ for hook in self.before_step_hooks:
+ hook(step_idx, transition)
+
+ # 2. Execute the actual processing step
+ transition = processor_step(transition)
+
+ # 3. Call all AFTER hooks
+ for hook in self.after_step_hooks:
+ hook(step_idx, transition)
+
+ return transition
+
+ def register_before_step_hook(self, hook_fn):
+ self.before_step_hooks.append(hook_fn)
+
+ def register_after_step_hook(self, hook_fn):
+ self.after_step_hooks.append(hook_fn)
+```
+
+### Execution Flow
+
+The execution flow looks like this:
+
+```
+Input → Before Hook → Step 0 → After Hook → Before Hook → Step 1 → After Hook → ... → Output
+```
+
+For example, with 3 steps and both hook types:
+
+```python
+def timing_before(step_idx, transition):
+ print(f"⏱️ Starting step {step_idx}")
+
+def validation_after(step_idx, transition):
+ print(f"✅ Completed step {step_idx}")
+
+processor.register_before_step_hook(timing_before)
+processor.register_after_step_hook(validation_after)
+
+# This will output:
+# ⏱️ Starting step 0
+# ✅ Completed step 0
+# ⏱️ Starting step 1
+# ✅ Completed step 1
+# ⏱️ Starting step 2
+# ✅ Completed step 2
+```
+
+### Multiple Hooks
+
+You can register multiple hooks of the same type - they execute in the order registered:
+
+```python
+def log_shapes(step_idx: int, transition: EnvTransition):
+ obs = transition.get(TransitionKey.OBSERVATION)
+ if obs:
+ print(f"Step {step_idx} observation shapes:")
+ for key, value in obs.items():
+ if isinstance(value, torch.Tensor):
+ print(f" {key}: {value.shape}")
+
+processor.register_after_step_hook(check_nans) # Executes first
+processor.register_after_step_hook(log_shapes) # Executes second
+
+# Both hooks will be called after each step in registration order
+output = processor(input_data)
+```
+
+While hooks are excellent for monitoring specific issues (like NaN detection) or gathering metrics during normal pipeline execution, sometimes you need to dive deeper. When you want to understand exactly what happens at each step or debug complex transformation logic, step-through debugging provides the detailed inspection you need.
+
+## Step-Through Debugging
+
+Step-through debugging is like having a slow-motion replay for your pipeline. Instead of watching your data get transformed in one quick blur from input to output, you can pause and examine what happens after each individual step.
+
+This approach is particularly valuable when you're trying to understand a complex pipeline, debug unexpected behavior, or verify that each transformation is working as expected. Unlike hooks, which are great for automated monitoring, step-through debugging gives you manual, interactive control over the inspection process.
+
+The `step_through()` method is a generator that yields the transition state after each processing step, allowing you to inspect intermediate results. Think of it as creating a series of snapshots of your data as it flows through the pipeline—each snapshot shows you exactly what your data looks like after one more transformation has been applied.
+
+### How Step-Through Works
+
+The `step_through()` method fundamentally changes how the pipeline executes. Instead of running all steps in sequence and only returning the final result, it transforms the pipeline into an iterator that yields intermediate results.
+
+Here's what happens internally: the method starts by converting your input data into the pipeline's internal transition format, then yields this initial state. Next, it applies the first processing step and yields the result. Then it applies the second step to that result and yields again, and so on. Each `yield` gives you a complete snapshot of the transition at that point.
+
+This generator pattern is powerful because it's lazy—the pipeline only computes the next step when you ask for it. This means you can stop at any point, inspect the current state thoroughly, and decide whether to continue. You're not forced to run the entire pipeline just to debug one problematic step.
+
+Instead of running the entire pipeline and only seeing the final result, `step_through()` pauses after each step and gives you the intermediate transition:
+
+```python
+# This creates a generator that yields intermediate states
+for i, intermediate_result in enumerate(processor.step_through(input_data)):
+ print(f"=== After step {i} ===")
+
+ # Inspect the observation at this stage
+ obs = intermediate_result.get(TransitionKey.OBSERVATION)
+ if obs:
+ for key, value in obs.items():
+ if isinstance(value, torch.Tensor):
+ print(f"{key}: shape={value.shape}, dtype={value.dtype}")
+```
+
+### Interactive Debugging with Breakpoints
+
+You can add breakpoints in the step-through loop to interactively debug:
+
+```python
+# Step through the pipeline with debugging
+for i, intermediate in enumerate(processor.step_through(data)):
+ print(f"Step {i}: {processor.steps[i].__class__.__name__}")
+
+ # Set a breakpoint to inspect the current state
+ breakpoint() # Debugger will pause here
+
+ # You can now inspect 'intermediate' in the debugger:
+ # - Check tensor shapes and values
+ # - Verify expected transformations
+ # - Look for unexpected changes
+```
+
+During the debugger session, you can:
+
+- Examine `intermediate[TransitionKey.OBSERVATION]` to see observation data
+- Check `intermediate[TransitionKey.ACTION]` for action transformations
+- Inspect any part of the transition to understand what each step does
+
+Step-through debugging is perfect for understanding the _data_ transformations, but what about the _structure_ of that data? While hooks and step-through help you debug runtime behavior, you also need to ensure your pipeline produces data in the format expected by downstream components. This is where feature contract validation comes in.
+
+## Validating Feature Contracts
+
+Feature contracts define what data structure your pipeline expects as input and produces as output.
+Validating these contracts helps catch mismatches early.
+
+### Understanding Feature Contracts
+
+Each processor step has a `transform_features()` method that describes how it changes the data structure:
+
+```python
+# Get the expected output features from your pipeline
+initial_features = {
+ PipelineFeatureType.OBSERVATION: {
+ "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)),
+ "observation.image": PolicyFeature(type=FeatureType.IMAGE, shape=(3, 224, 224))
+ },
+ PipelineFeatureType.ACTION: {
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,))
+ }
+}
+
+# Check what your pipeline will output
+output_features = processor.transform_features(initial_features)
+
+print("Input features:")
+for feature_type, features in initial_features.items():
+ print(f" {feature_type}:")
+ for key, feature in features.items():
+ print(f" {key}: {feature.type.value}, shape={feature.shape}")
+
+print("\nOutput features:")
+for feature_type, features in output_features.items():
+ print(f" {feature_type}:")
+ for key, feature in features.items():
+ print(f" {key}: {feature.type.value}, shape={feature.shape}")
+```
+
+### Verifying Expected Features
+
+Check that your pipeline produces the features you expect:
+
+```python
+# Define what features you expect the pipeline to produce
+expected_keys = ["observation.state", "observation.image", "action"]
+
+print("Validating feature contract...")
+for expected_key in expected_keys:
+ found = False
+ for feature_type, features in output_features.items():
+ if expected_key in features:
+ feature = features[expected_key]
+ print(f"✅ {expected_key}: {feature.type.value}, shape={feature.shape}")
+ found = True
+ break
+
+ if not found:
+ print(f"❌ Missing expected feature: {expected_key}")
+```
+
+This validation helps ensure your pipeline will work correctly with downstream components that expect specific data structures.
+
+## Summary
+
+Now that you understand the three debugging approaches, you can tackle any pipeline issue systematically:
+
+1. **Hooks** - For runtime monitoring and validation without modifying pipeline code
+2. **Step-through** - For inspecting intermediate states and understanding transformations
+3. **Feature validation** - For ensuring data structure contracts are met
+
+**When to use each approach:**
+
+- Start with **step-through debugging** when you need to understand what your pipeline does or when something unexpected happens
+- Add **hooks** for continuous monitoring during development and production to catch issues automatically
+- Use **feature validation** before deployment to ensure your pipeline works with downstream components
+
+These three tools work together to give you the complete observability that complex pipelines naturally lack. With hooks watching for issues, step-through helping you understand behavior, and feature validation ensuring compatibility, you'll be able to debug any pipeline confidently and efficiently.
diff --git a/docs/source/feetech.mdx b/docs/source/feetech.mdx
new file mode 100644
index 000000000..bba60e4cc
--- /dev/null
+++ b/docs/source/feetech.mdx
@@ -0,0 +1,71 @@
+# Feetech Motor Firmware Update
+
+This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software.
+
+## Prerequisites
+
+- Windows computer (Feetech software is only available for Windows)
+- Feetech motor control board
+- USB cable to connect the control board to your computer
+- Feetech motors connected to the control board
+
+## Step 1: Download Feetech Software
+
+1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html)
+2. Download the latest version of the Feetech debugging software (FD)
+3. Install the software on your Windows computer
+
+## Step 2: Hardware Setup
+
+1. Connect your Feetech motors to the motor control board
+2. Connect the motor control board to your Windows computer via USB cable
+3. Ensure power is supplied to the motors
+
+## Step 3: Configure Connection
+
+1. Launch the Feetech debugging software
+2. Select the correct COM port from the port dropdown menu
+ - If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)"
+3. Set the appropriate baud rate (typically 1000000 for most Feetech motors)
+4. Click "Open" to establish communication with the control board
+
+## Step 4: Scan for Motors
+
+1. Once connected, click the "Search" button to detect all connected motors
+2. The software will automatically discover and list all motors on the bus
+3. Each motor will appear with its ID number
+
+## Step 5: Update Firmware
+
+For each motor you want to update:
+
+1. **Select the motor** from the list by clicking on it
+2. **Click on Upgrade tab**:
+3. **Click on Online button**:
+ - If an potential firmware update is found, it will be displayed in the box
+4. **Click on Upgrade button**:
+ - The update progress will be displayed
+
+## Step 6: Verify Update
+
+1. After the update completes, the software should automatically refresh the motor information
+2. Verify that the firmware version has been updated to the expected version
+
+## Important Notes
+
+⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor.
+
+## Bonus: Motor Debugging on Linux/macOS
+
+For debugging purposes only, you can use the open-source Feetech Debug Tool:
+
+- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer)
+
+### Installation Instructions
+
+Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source.
+
+**Limitations:**
+
+- This tool is for debugging and parameter adjustment only
+- Firmware updates must still be done on Windows with official Feetech software
diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx
index b3ab40c89..ad1c74f9a 100644
--- a/docs/source/hilserl.mdx
+++ b/docs/source/hilserl.mdx
@@ -5,17 +5,27 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
It combines three key ingredients:
- 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
- 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
- 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
+
+1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
+
+2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
+
+3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
-
+
-
HIL-SERL workflow, Luo et al. 2024
+
+ HIL-SERL workflow, Luo et al. 2024
+
This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot.
@@ -24,11 +34,12 @@ This guide provides step-by-step instructions for training a robot policy using
- A gamepad (recommended) or keyboard to control the robot
- A Nvidia GPU
- A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad)
-- A URDF file for the robot for the kinematics package (check `lerobot/common/model/kinematics.py`)
+- A URDF file for the robot for the kinematics package (check `lerobot/model/kinematics.py`)
## What kind of tasks can I train?
One can use HIL-SERL to train on a variety of manipulation tasks. Some recommendations:
+
- Start with a simple task to understand how the system works.
- Push cube to a goal region
- Pick and lift cube with the gripper
@@ -51,28 +62,238 @@ pip install -e ".[hilserl]"
### Understanding Configuration
-The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
+The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
+
```python
+class GymManipulatorConfig:
+ env: HILSerlRobotEnvConfig # Environment configuration (nested)
+ dataset: DatasetConfig # Dataset recording/replay configuration (nested)
+ mode: str | None = None # "record", "replay", or None (for training)
+ device: str = "cpu" # Compute device
+
class HILSerlRobotEnvConfig(EnvConfig):
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
- teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
- wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
- fps: int = 10 # Control frequency
+ teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
+ processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
name: str = "real_robot" # Environment name
- mode: str = None # "record", "replay", or None (for training)
- repo_id: str | None = None # LeRobot dataset repository ID
- dataset_root: str | None = None # Local dataset root (optional)
- task: str = "" # Task identifier
- num_episodes: int = 10 # Number of episodes for recording
- episode: int = 0 # episode index for replay
- device: str = "cuda" # Compute device
- push_to_hub: bool = True # Whether to push the recorded datasets to Hub
- pretrained_policy_name_or_path: str | None = None # For policy loading
- reward_classifier_pretrained_path: str | None = None # For reward model
- number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
+ task: str | None = None # Task identifier
+ fps: int = 10 # Control frequency
+
+# Nested processor configuration
+class HILSerlProcessorConfig:
+ control_mode: str = "gamepad" # Control mode
+ observation: ObservationConfig | None = None # Observation processing settings
+ image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
+ gripper: GripperConfig | None = None # Gripper control and penalty settings
+ reset: ResetConfig | None = None # Environment reset and timing settings
+ inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
+ reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
+ max_gripper_pos: float | None = 100.0 # Maximum gripper position
+
+# Sub-configuration classes
+class ObservationConfig:
+ add_joint_velocity_to_observation: bool = False # Add joint velocities to state
+ add_current_to_observation: bool = False # Add motor currents to state
+ display_cameras: bool = False # Display camera feeds during execution
+
+class ImagePreprocessingConfig:
+ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
+ resize_size: tuple[int, int] | None = None # Target image size
+
+class GripperConfig:
+ use_gripper: bool = True # Enable gripper control
+ gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
+
+class ResetConfig:
+ fixed_reset_joint_positions: Any | None = None # Joint positions for reset
+ reset_time_s: float = 5.0 # Time to wait during reset
+ control_time_s: float = 20.0 # Maximum episode duration
+ terminate_on_success: bool = True # Whether to terminate episodes on success detection
+
+class InverseKinematicsConfig:
+ urdf_path: str | None = None # Path to robot URDF file
+ target_frame_name: str | None = None # End-effector frame name
+ end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
+ end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
+
+class RewardClassifierConfig:
+ pretrained_path: str | None = None # Path to pretrained reward classifier
+ success_threshold: float = 0.5 # Success detection threshold
+ success_reward: float = 1.0 # Reward value for successful episodes
+
+# Dataset configuration
+class DatasetConfig:
+ repo_id: str # LeRobot dataset repository ID
+ task: str # Task identifier
+ root: str | None = None # Local dataset root directory
+ num_episodes_to_record: int = 5 # Number of episodes for recording
+ replay_episode: int | None = None # Episode index for replay
+ push_to_hub: bool = False # Whether to push datasets to Hub
+```
+
+
+### Processor Pipeline Architecture
+
+HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
+
+#### Environment Processor Pipeline
+
+The environment processor (`env_processor`) handles incoming observations and environment state:
+
+1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
+2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
+3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
+4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
+5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
+6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
+7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
+8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
+9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
+10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
+
+#### Action Processor Pipeline
+
+The action processor (`action_processor`) handles outgoing actions and human interventions:
+
+1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
+2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
+3. **InterventionActionProcessorStep**: Handles human interventions and episode termination
+4. **Inverse Kinematics Pipeline** (when enabled):
+ - **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
+ - **EEReferenceAndDelta**: Computes end-effector reference and delta movements
+ - **EEBoundsAndSafety**: Enforces workspace safety bounds
+ - **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
+ - **GripperVelocityToJoint**: Handles gripper control commands
+
+#### Configuration Examples
+
+**Basic Observation Processing**:
+
+```json
+{
+ "env": {
+ "processor": {
+ "observation": {
+ "add_joint_velocity_to_observation": true,
+ "add_current_to_observation": false,
+ "display_cameras": false
+ }
+ }
+ }
+}
```
+**Image Processing**:
+
+```json
+{
+ "env": {
+ "processor": {
+ "image_preprocessing": {
+ "crop_params_dict": {
+ "observation.images.front": [180, 250, 120, 150],
+ "observation.images.side": [180, 207, 180, 200]
+ },
+ "resize_size": [128, 128]
+ }
+ }
+ }
+}
+```
+
+**Inverse Kinematics Setup**:
+
+```json
+{
+ "env": {
+ "processor": {
+ "inverse_kinematics": {
+ "urdf_path": "path/to/robot.urdf",
+ "target_frame_name": "end_effector",
+ "end_effector_bounds": {
+ "min": [0.16, -0.08, 0.03],
+ "max": [0.24, 0.2, 0.1]
+ },
+ "end_effector_step_sizes": {
+ "x": 0.02,
+ "y": 0.02,
+ "z": 0.02
+ }
+ }
+ }
+ }
+}
+```
+
+### Advanced Observation Processing
+
+The HIL-SERL framework supports additional observation processing features that can improve policy learning:
+
+#### Joint Velocity Processing
+
+Enable joint velocity estimation to provide the policy with motion information:
+
+```json
+{
+ "env": {
+ "processor": {
+ "observation": {
+ "add_joint_velocity_to_observation": true
+ }
+ }
+ }
+}
+```
+
+This processor:
+
+- Estimates joint velocities using finite differences between consecutive joint position readings
+- Adds velocity information to the observation state vector
+- Useful for policies that need motion awareness for dynamic tasks
+
+#### Motor Current Processing
+
+Monitor motor currents to detect contact forces and load conditions:
+
+```json
+{
+ "env": {
+ "processor": {
+ "observation": {
+ "add_current_to_observation": true
+ }
+ }
+ }
+}
+```
+
+This processor:
+
+- Reads motor current values from the robot's control system
+- Adds current measurements to the observation state vector
+- Helps detect contact events, object weights, and mechanical resistance
+- Useful for contact-rich manipulation tasks
+
+#### Combined Observation Processing
+
+You can enable multiple observation processing features simultaneously:
+
+```json
+{
+ "env": {
+ "processor": {
+ "observation": {
+ "add_joint_velocity_to_observation": true,
+ "add_current_to_observation": true,
+ "display_cameras": false
+ }
+ }
+ }
+}
+```
+
+**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
### Finding Robot Workspace Bounds
@@ -80,19 +301,19 @@ Before collecting demonstrations, you need to determine the appropriate operatio
This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates.
-**Using find_joint_limits.py**
+**Using lerobot-find-joint-limits**
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
Bounding the action space will reduce the redundant exploration of the agent and guarantees safety.
```bash
-python -m lerobot.scripts.find_joint_limits \
- --robot.type=so100_follower \
- --robot.port=/dev/tty.usbmodem58760431541 \
- --robot.id=black \
- --teleop.type=so100_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \
- --teleop.id=blue
+lerobot-find-joint-limits \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58760431541 \
+ --robot.id=black \
+ --teleop.type=so100_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \
+ --teleop.id=blue
```
**Workflow**
@@ -122,23 +343,58 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
**Setting Up Record Mode**
-Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
+Create a configuration file for recording demonstrations (or edit an existing one like [env_config.json](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/env_config.json)):
-1. Set `mode` to `"record"`
-2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
-3. Set `num_episodes` to the number of demonstrations you want to collect
-4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
-5. Configure `robot`, `cameras`, and other hardware settings
+1. Set `mode` to `"record"` at the root level
+2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
+3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
+4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
+5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
Example configuration section:
+
```json
-"mode": "record",
-"repo_id": "username/pick_lift_cube",
-"dataset_root": null,
-"task": "pick_and_lift",
-"num_episodes": 15,
-"episode": 0,
-"push_to_hub": true
+{
+ "env": {
+ "type": "gym_manipulator",
+ "name": "real_robot",
+ "fps": 10,
+ "processor": {
+ "control_mode": "gamepad",
+ "observation": {
+ "display_cameras": false
+ },
+ "image_preprocessing": {
+ "crop_params_dict": {},
+ "resize_size": [128, 128]
+ },
+ "gripper": {
+ "use_gripper": true,
+ "gripper_penalty": 0.0
+ },
+ "reset": {
+ "reset_time_s": 5.0,
+ "control_time_s": 20.0
+ }
+ },
+ "robot": {
+ // ... robot configuration ...
+ },
+ "teleop": {
+ // ... teleoperator configuration ...
+ }
+ },
+ "dataset": {
+ "repo_id": "username/pick_lift_cube",
+ "root": null,
+ "task": "pick_and_lift",
+ "num_episodes_to_record": 15,
+ "replay_episode": 0,
+ "push_to_hub": true
+ },
+ "mode": "record",
+ "device": "cpu"
+}
```
### Using a Teleoperation Device
@@ -150,6 +406,7 @@ HIL-Serl learns actions in the end-effector space of the robot. Therefore, the t
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
+
```python
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
@@ -172,6 +429,7 @@ class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
}
)
```
+
The `Teleoperator` defines the teleoperation device. You can check the list of available teleoperators in `lerobot/teleoperators`.
@@ -182,16 +440,33 @@ The gamepad provides a very convenient way to control the robot and the episode
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
```json
+{
+ "env": {
"teleop": {
- "type": "gamepad",
- "use_gripper": true
+ "type": "gamepad",
+ "use_gripper": true
},
+ "processor": {
+ "control_mode": "gamepad",
+ "gripper": {
+ "use_gripper": true
+ }
+ }
+ }
+}
```
-
+
+
+
+ Gamepad button mapping for robot control and episode management
-
Gamepad button mapping for robot control and episode management
**Setting up the SO101 leader**
@@ -200,11 +475,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
```json
+{
+ "env": {
"teleop": {
- "type": "so101_leader",
- "port": "/dev/tty.usbmodem585A0077921", # check your port number
- "use_degrees": true
+ "type": "so101_leader",
+ "port": "/dev/tty.usbmodem585A0077921",
+ "use_degrees": true
},
+ "processor": {
+ "control_mode": "leader",
+ "gripper": {
+ "use_gripper": true
+ }
+ }
+ }
+}
```
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
@@ -215,7 +500,10 @@ During the online training, press `space` to take over the policy and `space` ag
@@ -227,11 +515,12 @@ During the online training, press `space` to take over the policy and `space` ag
Start the recording process, an example of the config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json):
```bash
-python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config_so100.json
+python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config_so100.json
```
During recording:
-1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
+
+1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
2. Complete the task successfully
3. The episode ends with a reward of 1 when you press the "success" button
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
@@ -239,13 +528,13 @@ During recording:
6. The process automatically continues to the next episode
7. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally
-
### Processing the Dataset
After collecting demonstrations, process them to determine optimal camera crops.
Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area.
Visual RL algorithms learn directly from pixel inputs, making them vulnerable to irrelevant visual information. Background elements like changing lighting, shadows, people moving, or objects outside the workspace can confuse the learning process. Good ROI selection should:
+
- Include only the essential workspace where the task happens
- Capture the robot's end-effector and all objects involved in the task
- Exclude unnecessary background elements and distractions
@@ -257,7 +546,7 @@ Note: If you already know the crop parameters, you can skip this step and just s
Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images:
```bash
-python -m lerobot.scripts.rl.crop_dataset_roi --repo-id username/pick_lift_cube
+python -m lerobot.rl.crop_dataset_roi --repo-id username/pick_lift_cube
```
1. For each camera view, the script will display the first frame
@@ -267,6 +556,7 @@ python -m lerobot.scripts.rl.crop_dataset_roi --repo-id username/pick_lift_cube
5. The script outputs cropping parameters and creates a new cropped dataset
Example output:
+
```
Selected Rectangular Regions of Interest (top, left, height, width):
observation.images.side: [180, 207, 180, 200]
@@ -274,28 +564,39 @@ observation.images.front: [180, 250, 120, 150]
```
-
+
-
Interactive cropping tool for selecting regions of interest
-
+
+ Interactive cropping tool for selecting regions of interest
+
**Updating Configuration**
Add these crop parameters to your training configuration:
```json
-"crop_params_dict": {
- "observation.images.side": [180, 207, 180, 200],
- "observation.images.front": [180, 250, 120, 150]
-},
-"resize_size": [128, 128]
+{
+ "env": {
+ "processor": {
+ "image_preprocessing": {
+ "crop_params_dict": {
+ "observation.images.side": [180, 207, 180, 200],
+ "observation.images.front": [180, 250, 120, 150]
+ },
+ "resize_size": [128, 128]
+ }
+ }
+ }
+}
```
**Recommended image resolution**
-Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested.
-
+Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested.
### Training a Reward Classifier
@@ -314,31 +615,57 @@ Before training, you need to collect a dataset with labeled examples. The `recor
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
```bash
-python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/reward_classifier_train_config.json
+python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/reward_classifier_train_config.json
```
**Key Parameters for Data Collection**
-- **mode**: set it to `"record"` to collect a dataset
-- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
-- **num_episodes**: Number of episodes to record
-- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
-- **fps**: Number of frames per second to record
-- **push_to_hub**: Whether to push the dataset to the hub
+- **mode**: set it to `"record"` to collect a dataset (at root level)
+- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
+- **dataset.num_episodes_to_record**: Number of episodes to record
+- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
+- **env.fps**: Number of frames per second to record
+- **dataset.push_to_hub**: Whether to push the dataset to the hub
-The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
+The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
+
+**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
Example configuration section for data collection:
```json
{
- "mode": "record",
+ "env": {
+ "type": "gym_manipulator",
+ "name": "real_robot",
+ "fps": 10,
+ "processor": {
+ "reset": {
+ "reset_time_s": 5.0,
+ "control_time_s": 20.0,
+ "terminate_on_success": false
+ },
+ "gripper": {
+ "use_gripper": true
+ }
+ },
+ "robot": {
+ // ... robot configuration ...
+ },
+ "teleop": {
+ // ... teleoperator configuration ...
+ }
+ },
+ "dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
- "num_episodes": 20,
- "push_to_hub": true,
- "fps": 10,
- "number_of_steps_after_success": 15
+ "task": "reward_classifier_task",
+ "num_episodes_to_record": 20,
+ "replay_episode": null,
+ "push_to_hub": true
+ },
+ "mode": "record",
+ "device": "cpu"
}
```
@@ -388,30 +715,53 @@ Example configuration for training the [reward classifier](https://huggingface.c
To train the classifier, use the `train.py` script with your configuration:
```bash
-python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json
+lerobot-train --config_path path/to/reward_classifier_train_config.json
```
**Deploying and Testing the Model**
To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model:
+
```python
-env_config = HILSerlRobotEnvConfig(
- reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
- # Other environment parameters
+config = GymManipulatorConfig(
+ env=HILSerlRobotEnvConfig(
+ processor=HILSerlProcessorConfig(
+ reward_classifier=RewardClassifierConfig(
+ pretrained_path="path_to_your_pretrained_trained_model"
+ )
+ ),
+ # Other environment parameters
+ ),
+ dataset=DatasetConfig(...),
+ mode=None # For training
)
```
+
+
or set the argument in the json config file.
```json
{
- "reward_classifier_pretrained_path": "path_to_your_pretrained_model"
+ "env": {
+ "processor": {
+ "reward_classifier": {
+ "pretrained_path": "path_to_your_pretrained_model",
+ "success_threshold": 0.7,
+ "success_reward": 1.0
+ },
+ "reset": {
+ "terminate_on_success": true
+ }
+ }
+ }
}
```
Run `gym_manipulator.py` to test the model.
+
```bash
-python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config.json
+python -m lerobot.rl.gym_manipulator --config_path path/to/env_config.json
```
The reward classifier will automatically provide rewards based on the visual input from the robot's cameras.
@@ -419,21 +769,23 @@ The reward classifier will automatically provide rewards based on the visual inp
**Example Workflow for training the reward classifier**
1. **Create the configuration files**:
- Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main).
+ Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/reward_classifier/config.json).
2. **Collect a dataset**:
+
```bash
- python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json
+ python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json
```
3. **Train the classifier**:
+
```bash
- python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json
+ lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json
```
4. **Test the classifier**:
```bash
- python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json
+ python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json
```
### Training with Actor-Learner
@@ -442,12 +794,12 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
**Configuration Setup**
-Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
+Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
1. Configure the policy settings (`type="sac"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
-4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/policies/sac/configuration_sac.py#L79).
+4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**
@@ -455,10 +807,11 @@ Create a training configuration file (example available [here](https://huggingfa
First, start the learner server process:
```bash
-python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
+python -m lerobot.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
The learner:
+
- Initializes the policy network
- Prepares replay buffers
- Opens a `gRPC` server to communicate with actors
@@ -469,10 +822,11 @@ The learner:
In a separate terminal, start the actor process with the same configuration:
```bash
-python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
+python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
The actor:
+
- Connects to the learner via `gRPC`
- Initializes the environment
- Execute rollouts of the policy to collect experience
@@ -496,10 +850,19 @@ The training proceeds automatically:
- A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard.
-
+
-
Example showing how human interventions help guide policy learning over time
+
+
+ Example showing how human interventions help guide policy learning over time
+
+
- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning.
- The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions.
@@ -510,7 +873,9 @@ The training proceeds automatically:
If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard.
### Guide to Human Interventions
+
The learning process is very sensitive to the intervention strategy. It will takes a few runs to understand how to intervene effectively. Some tips and hints:
+
- Allow the policy to explore for a few episodes at the start of training.
- Avoid intervening for long periods of time. Try to intervene in situation to correct the robot's behaviour when it goes off track.
- Once the policy starts achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a simple grasping commands.
@@ -518,26 +883,36 @@ The learning process is very sensitive to the intervention strategy. It will tak
The ideal behaviour is that your intervention rate should drop gradually during training as shown in the figure below.
-
+
-
Plot of the intervention rate during a training run on a pick and lift cube task
+
+
+ Plot of the intervention rate during a training run on a pick and lift cube
+ task
+
+
### Key hyperparameters to tune
Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
-- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in *seconds* between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
+- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
-
Congrats 🎉, you have finished this tutorial!
> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
Paper citation:
+
```
@article{luo2024precise,
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx
index ad7a9584a..e2dddd9ed 100644
--- a/docs/source/hilserl_sim.mdx
+++ b/docs/source/hilserl_sim.mdx
@@ -11,7 +11,6 @@ This guide explains how to use the `gym_hil` simulation environments as an alter
Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube.
-
## Installation
First, install the `gym_hil` package within the LeRobot environment:
@@ -25,49 +24,64 @@ pip install -e ".[hilserl]"
- A gamepad or keyboard to control the robot
- A Nvidia GPU
-
-
## Configuration
-To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include:
+To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/env_config.json). Key configuration sections include:
### Environment Type and Task
```json
{
- "type": "hil",
- "name": "franka_sim",
+ "env": {
+ "type": "gym_manipulator",
+ "name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
- "device": "cuda"
+ "fps": 10
+ },
+ "device": "cuda"
}
```
Available tasks:
+
- `PandaPickCubeBase-v0`: Basic environment
- `PandaPickCubeGamepad-v0`: With gamepad control
- `PandaPickCubeKeyboard-v0`: With keyboard control
-### Gym Wrappers Configuration
+### Processor Configuration
```json
-"wrapper": {
- "gripper_penalty": -0.02,
- "control_time_s": 15.0,
- "use_gripper": true,
- "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
- "end_effector_step_sizes": {
- "x": 0.025,
- "y": 0.025,
- "z": 0.025
- },
- "control_mode": "gamepad"
+{
+ "env": {
+ "processor": {
+ "control_mode": "gamepad",
+ "gripper": {
+ "use_gripper": true,
+ "gripper_penalty": -0.02
+ },
+ "reset": {
+ "control_time_s": 15.0,
+ "fixed_reset_joint_positions": [
+ 0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
+ ]
+ },
+ "inverse_kinematics": {
+ "end_effector_step_sizes": {
+ "x": 0.025,
+ "y": 0.025,
+ "z": 0.025
+ }
+ }
}
+ }
+}
```
Important parameters:
-- `gripper_penalty`: Penalty for excessive gripper movement
-- `use_gripper`: Whether to enable gripper control
-- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
+
+- `gripper.gripper_penalty`: Penalty for excessive gripper movement
+- `gripper.use_gripper`: Whether to enable gripper control
+- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
## Running with HIL RL of LeRobot
@@ -76,30 +90,49 @@ Important parameters:
To run the environment, set mode to null:
-```python
-python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
+```bash
+python -m lerobot.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
### Recording a Dataset
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
-```python
-python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
+```json
+{
+ "env": {
+ "type": "gym_manipulator",
+ "name": "gym_hil",
+ "task": "PandaPickCubeGamepad-v0"
+ },
+ "dataset": {
+ "repo_id": "username/sim_dataset",
+ "root": null,
+ "task": "pick_cube",
+ "num_episodes_to_record": 10,
+ "replay_episode": null,
+ "push_to_hub": true
+ },
+ "mode": "record"
+}
+```
+
+```bash
+python -m lerobot.rl.gym_manipulator --config_path path/to/gym_hil_env.json
```
### Training a Policy
-To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
+To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/train_config.json) and run the actor and learner servers:
-```python
-python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
+```bash
+python -m lerobot.rl.actor --config_path path/to/train_gym_hil_env.json
```
In a different terminal, run the learner server:
-```python
-python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
+```bash
+python -m lerobot.rl.learner --config_path path/to/train_gym_hil_env.json
```
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
@@ -107,9 +140,10 @@ The simulation environment provides a safe and repeatable way to develop and tes
Congrats 🎉, you have finished this tutorial!
> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
Paper citation:
+
```
@article{luo2024precise,
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx
new file mode 100644
index 000000000..856febb95
--- /dev/null
+++ b/docs/source/hope_jr.mdx
@@ -0,0 +1,277 @@
+# HopeJR
+
+## Prerequisites
+
+- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr)
+
+## Install LeRobot
+
+Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
+
+Install LeRobot with HopeJR dependencies:
+
+```bash
+pip install -e ".[hopejr]"
+```
+
+## Device Configuration
+
+Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton:
+
+```bash
+lerobot-find-port
+```
+
+This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts.
+
+## Step 1: Calibration
+
+Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration`
+
+### 1.1 Calibrate Robot Hand
+
+```bash
+lerobot-calibrate \
+ --robot.type=hope_jr_hand \
+ --robot.port=/dev/tty.usbmodem58760432281 \
+ --robot.id=blue \
+ --robot.side=right
+```
+
+When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows:
+
+**Thumb**:
+
+- **CMC**: base joint connecting thumb to hand
+- **MCP**: knuckle joint
+- **PIP**: first finger joint
+- **DIP** : fingertip joint
+
+**Index, Middle, Ring, and Pinky fingers**:
+
+- **Radial flexor**: Moves base of finger towards the thumb
+- **Ulnar flexor**: Moves base of finger towards the pinky
+- **PIP/DIP**: Flexes the distal and proximal phalanx of the finger
+
+Each one of these will need to be calibrated individually via the GUI.
+Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement.
+
+
+
+
+
+Use the calibration interface to set the range boundaries for each joint as shown above.
+
+
+
+
+
+Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors.
+
+### 1.2 Calibrate Teleoperator Glove
+
+```bash
+lerobot-calibrate \
+ --teleop.type=homunculus_glove \
+ --teleop.port=/dev/tty.usbmodem11201 \
+ --teleop.id=red \
+ --teleop.side=right
+```
+
+Move each finger through its full range of motion, starting from the thumb.
+
+```
+Move thumb through its entire range of motion.
+Recording positions. Press ENTER to stop...
+
+-------------------------------------------
+NAME | MIN | POS | MAX
+thumb_cmc | 1790 | 1831 | 1853
+thumb_mcp | 1497 | 1514 | 1528
+thumb_pip | 1466 | 1496 | 1515
+thumb_dip | 1463 | 1484 | 1514
+```
+
+Continue with each finger:
+
+```
+Move middle through its entire range of motion.
+Recording positions. Press ENTER to stop...
+
+-------------------------------------------
+NAME | MIN | POS | MAX
+middle_mcp_abduction | 1598 | 1718 | 1820
+middle_mcp_flexion | 1512 | 1658 | 2136
+middle_dip | 1484 | 1500 | 1547
+```
+
+Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json`
+
+### 1.3 Calibrate Robot Arm
+
+```bash
+lerobot-calibrate \
+ --robot.type=hope_jr_arm \
+ --robot.port=/dev/tty.usbserial-1110 \
+ --robot.id=white
+```
+
+This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows:
+
+- **Shoulder**: pitch, yaw, and roll
+- **Elbow**: flex
+- **Wrist**: pitch, yaw, and roll
+
+
+
+
+
+Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration.
+
+### 1.4 Calibrate Teleoperator Exoskeleton
+
+```bash
+lerobot-calibrate \
+ --teleop.type=homunculus_arm \
+ --teleop.port=/dev/tty.usbmodem11201 \
+ --teleop.id=black
+```
+
+The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion:
+
+```
+Move all joints through their entire range of motion.
+Recording positions. Press ENTER to stop...
+
+-------------------------------------------
+-------------------------------------------
+NAME | MIN | POS | MAX
+shoulder_pitch | 586 | 736 | 895
+shoulder_yaw | 1257 | 1374 | 1390
+shoulder_roll | 449 | 1034 | 2564
+elbow_flex | 3023 | 3117 | 3134
+wrist_roll | 3073 | 3096 | 3147
+wrist_yaw | 2143 | 2171 | 2185
+wrist_pitch | 1975 | 1993 | 2074
+Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json
+```
+
+## Step 2: Teleoperation
+
+Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions:
+
+### Hand
+
+```bash
+lerobot-teleoperate \
+ --robot.type=hope_jr_hand \
+ --robot.port=/dev/tty.usbmodem58760432281 \
+ --robot.id=blue \
+ --robot.side=right \
+ --teleop.type=homunculus_glove \
+ --teleop.port=/dev/tty.usbmodem11201 \
+ --teleop.id=red \
+ --teleop.side=right \
+ --display_data=true \
+ --fps=30
+```
+
+### Arm
+
+```bash
+lerobot-teleoperate \
+ --robot.type=hope_jr_arm \
+ --robot.port=/dev/tty.usbserial-1110 \
+ --robot.id=white \
+ --teleop.type=homunculus_arm \
+ --teleop.port=/dev/tty.usbmodem11201 \
+ --teleop.id=black \
+ --display_data=true \
+ --fps=30
+```
+
+## Step 3: Record, Replay, Train
+
+Record, Replay and Train with Hope-JR is still experimental.
+
+### Record
+
+This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings).
+
+```bash
+lerobot-record \
+ --robot.type=hope_jr_hand \
+ --robot.port=/dev/tty.usbmodem58760432281 \
+ --robot.id=right \
+ --robot.side=right \
+ --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
+ --teleop.type=homunculus_glove \
+ --teleop.port=/dev/tty.usbmodem1201 \
+ --teleop.id=right \
+ --teleop.side=right \
+ --dataset.repo_id=nepyope/hand_record_test_with_video_data \
+ --dataset.single_task="Hand recording test with video data" \
+ --dataset.num_episodes=1 \
+ --dataset.episode_time_s=5 \
+ --dataset.push_to_hub=true \
+ --dataset.private=true \
+ --display_data=true
+```
+
+### Replay
+
+```bash
+lerobot-replay \
+ --robot.type=hope_jr_hand \
+ --robot.port=/dev/tty.usbmodem58760432281 \
+ --robot.id=right \
+ --robot.side=right \
+ --dataset.repo_id=nepyope/hand_record_test_with_camera \
+ --dataset.episode=0
+```
+
+### Train
+
+```bash
+lerobot-train \
+ --dataset.repo_id=nepyope/hand_record_test_with_video_data \
+ --policy.type=act \
+ --output_dir=outputs/train/hopejr_hand \
+ --job_name=hopejr \
+ --policy.device=mps \
+ --wandb.enable=true \
+ --policy.repo_id=nepyope/hand_test_policy
+```
+
+### Evaluate
+
+This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino).
+
+```bash
+lerobot-record \
+ --robot.type=hope_jr_hand \
+ --robot.port=/dev/tty.usbmodem58760432281 \
+ --robot.id=right \
+ --robot.side=right \
+ --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
+ --display_data=false \
+ --dataset.repo_id=nepyope/eval_hopejr \
+ --dataset.single_task="Evaluate hopejr hand policy" \
+ --dataset.num_episodes=10 \
+ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
+```
diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx
index cfa0a2809..91df14028 100644
--- a/docs/source/il_robots.mdx
+++ b/docs/source/il_robots.mdx
@@ -3,6 +3,7 @@
This tutorial will explain how to train a neural network to control a real robot autonomously.
**You'll learn:**
+
1. How to record and visualize your dataset.
2. How to train a policy using your data and prepare it for evaluation.
3. How to evaluate your policy and visualize the results.
@@ -14,7 +15,10 @@ By following these steps, you'll be able to replicate tasks, such as picking up
@@ -41,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file
```bash
-python -m lerobot.teleoperate \
+lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -51,6 +55,8 @@ python -m lerobot.teleoperate \
```
+
+
```python
from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
@@ -74,10 +80,13 @@ while True:
action = teleop_device.get_action()
robot.send_action(action)
```
+
+
The teleoperate command will automatically:
+
1. Identify any missing calibrations and initiate the calibration procedure.
2. Connect the robot and teleop device and start teleoperation.
@@ -92,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
```bash
-python -m lerobot.teleoperate \
+lerobot-teleoperate \
--robot.type=koch_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -104,6 +113,8 @@ python -m lerobot.teleoperate \
```
+
+
```python
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader
@@ -134,6 +145,8 @@ while True:
action = teleop_device.get_action()
robot.send_action(action)
```
+
+
@@ -144,11 +157,13 @@ Once you're familiar with teleoperation, you can record your first dataset.
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
Add your token to the CLI by running this command:
+
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
+
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
@@ -159,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th
```bash
-python -m lerobot.record \
+lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem585A0076841 \
--robot.id=my_awesome_follower_arm \
@@ -174,6 +189,8 @@ python -m lerobot.record \
```
+
+
```python
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -183,7 +200,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
-from lerobot.utils.visualization_utils import _init_rerun
+from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
NUM_EPISODES = 5
@@ -220,7 +237,7 @@ dataset = LeRobotDataset.create(
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
-_init_rerun(session_name="recording")
+init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
@@ -270,34 +287,49 @@ robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
```
+
+
#### Dataset upload
-Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
+
+Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. `https://huggingface.co/datasets/${HF_USER}/so101_test`) that you can obtain by running:
+
```bash
echo https://huggingface.co/datasets/${HF_USER}/so101_test
```
+
Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
+You can also push your local dataset to the Hub manually, running:
+
+```bash
+huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
+```
+
#### Record function
The `record` function provides a suite of tools for capturing and managing data during robot operation:
##### 1. Data Storage
+
- Data is stored using the `LeRobotDataset` format and is stored on disk during recording.
- By default, the dataset is pushed to your Hugging Face page after recording.
- To disable uploading, use `--dataset.push_to_hub=False`.
##### 2. Checkpointing and Resuming
+
- Checkpoints are automatically created during recording.
-- If an issue occurs, you can resume by re-running the same command with `--resume=true`.
+- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
- To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters
+
Set the flow of data recording using command-line arguments:
+
- `--dataset.episode_time_s=60`
Duration of each data recording episode (default: **60 seconds**).
- `--dataset.reset_time_s=60`
@@ -306,7 +338,9 @@ Set the flow of data recording using command-line arguments:
Total number of episodes to record (default: **50**).
##### 4. Keyboard Controls During Recording
+
Control the data recording flow using keyboard shortcuts:
+
- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next.
- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it.
- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset.
@@ -321,13 +355,14 @@ Avoid adding too much variation too quickly, as it may hinder your results.
If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset.
-
#### Troubleshooting:
+
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
+
```bash
echo ${HF_USER}/so101_test
```
@@ -341,7 +376,7 @@ You can replay the first episode on your robot with either the command below or
```bash
-python -m lerobot.replay \
+lerobot-replay \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=my_awesome_follower_arm \
@@ -350,6 +385,8 @@ python -m lerobot.replay \
```
+
+
```python
import time
@@ -382,6 +419,8 @@ for idx in range(dataset.num_frames):
robot.disconnect()
```
+
+
@@ -389,9 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example
## Train a policy
-To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--dataset.repo_id=${HF_USER}/so101_test \
--policy.type=act \
--output_dir=outputs/train/act_so101_test \
@@ -402,16 +442,18 @@ python -m lerobot.scripts.train \
```
Let's explain the command:
+
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`.
-2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
-5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
+2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
+3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
+4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours. You will find checkpoints in `outputs/train/act_so101_test/checkpoints`.
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
+
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
--resume=true
```
@@ -420,18 +462,21 @@ If you do not want to push your model to the hub after training use `--policy.pu
Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit`
-#### Train using Collab
-If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act).
+#### Train using Google Colab
+
+If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:
+
```bash
huggingface-cli upload ${HF_USER}/act_so101_test \
outputs/train/act_so101_test/checkpoints/last/pretrained_model
```
You can also upload intermediate checkpoints with:
+
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
@@ -440,12 +485,12 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
-You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
+You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
```bash
-python -m lerobot.record \
+lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
@@ -461,6 +506,8 @@ python -m lerobot.record \
```
+
+
```python
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -470,13 +517,16 @@ from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerCon
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
-from lerobot.utils.visualization_utils import _init_rerun
+from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
+from lerobot.policies.factory import make_processor
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
+HF_MODEL_ID = "/"
+HF_DATASET_ID = "/"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
@@ -488,7 +538,7 @@ robot_config = SO100FollowerConfig(
robot = SO100Follower(robot_config)
# Initialize the policy
-policy = ACTPolicy.from_pretrained("/")
+policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
@@ -497,7 +547,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
- repo_id="/eval_",
+ repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -507,11 +557,17 @@ dataset = LeRobotDataset.create(
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
-_init_rerun(session_name="recording")
+init_rerun(session_name="recording")
# Connect the robot
robot.connect()
+preprocessor, postprocessor = make_processor(
+ policy_cfg=policy,
+ pretrained_path=HF_MODEL_ID,
+ dataset_stats=dataset.meta.stats,
+)
+
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
@@ -521,6 +577,8 @@ for episode_idx in range(NUM_EPISODES):
events=events,
fps=FPS,
policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
@@ -533,9 +591,12 @@ for episode_idx in range(NUM_EPISODES):
robot.disconnect()
dataset.push_to_hub()
```
+
+
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
-1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
+
+1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx
index 048d3147e..9b7d7c111 100644
--- a/docs/source/il_sim.mdx
+++ b/docs/source/il_sim.mdx
@@ -3,6 +3,7 @@
This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning.
**You'll learn:**
+
1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset.
2. How to train a policy using your data.
3. How to evaluate your policy in simulation and visualize the results.
@@ -21,13 +22,38 @@ pip install -e ".[hilserl]"
## Teleoperate and Record a Dataset
-To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
+To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json).
-To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
+To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
-If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
+```json
+{
+ "env": {
+ "type": "gym_manipulator",
+ "name": "gym_hil",
+ "task": "PandaPickCubeGamepad-v0",
+ "fps": 10
+ },
+ "dataset": {
+ "repo_id": "your_username/il_gym",
+ "root": null,
+ "task": "pick_cube",
+ "num_episodes_to_record": 30,
+ "replay_episode": null,
+ "push_to_hub": true
+ },
+ "mode": "record",
+ "device": "cuda"
+}
+```
-By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
+Key configuration points:
+
+- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
+- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
+- Ensure `mode` is set to `"record"`
+- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
+- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
Then we can run this command to start:
@@ -35,14 +61,14 @@ Then we can run this command to start:
```bash
-python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
+python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
```
```bash
-mjpython -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
+mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
```
@@ -55,13 +81,21 @@ Note that to teleoperate the robot you have to hold the "Human Take Over Pause P
**Gamepad Controls**
-
+
+
+
+ Gamepad button mapping for robot control and episode management
-
Gamepad button mapping for robot control and episode management
**Keyboard controls**
For keyboard controls use the `spacebar` to enable control and the following keys to move the robot:
+
```bash
Arrow keys: Move in X-Y plane
Shift and Shift_R: Move in Z axis
@@ -74,16 +108,23 @@ For keyboard controls use the `spacebar` to enable control and the following key
If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id.
-
+
+
+
+ Dataset visualizer
-
Dataset visualizer
-
## Train a policy
-To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--dataset.repo_id=${HF_USER}/il_gym \
--policy.type=act \
--output_dir=outputs/train/il_sim_test \
@@ -93,25 +134,29 @@ python -m lerobot.scripts.train \
```
Let's explain the command:
+
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
-2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
-5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
+2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
+3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
+4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`.
#### Train using Collab
+
If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:
+
```bash
huggingface-cli upload ${HF_USER}/il_sim_test \
outputs/train/il_sim_test/checkpoints/last/pretrained_model
```
You can also upload intermediate checkpoints with:
+
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
@@ -120,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
## Evaluate your policy in Sim
-To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
+To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json).
-Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
+Here's an example evaluation configuration:
+
+```json
+{
+ "env": {
+ "type": "gym_manipulator",
+ "name": "gym_hil",
+ "task": "PandaPickCubeGamepad-v0",
+ "fps": 10
+ },
+ "dataset": {
+ "repo_id": "your_username/il_sim_dataset",
+ "dataset_root": null,
+ "task": "pick_cube"
+ },
+ "pretrained_policy_name_or_path": "your_username/il_sim_model",
+ "device": "cuda"
+}
+```
+
+Make sure to replace:
+
+- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
+- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
Then you can run this command to visualize your trained policy
@@ -130,23 +198,23 @@ Then you can run this command to visualize your trained policy
```bash
-python -m lerobot.scripts.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
+python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
```
```bash
-mjpython -m lerobot.scripts.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
+mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
```
> [!WARNING]
-> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
+> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim)
> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/docs/source/implement_your_own_processor.mdx b/docs/source/implement_your_own_processor.mdx
new file mode 100644
index 000000000..5b7d4f78a
--- /dev/null
+++ b/docs/source/implement_your_own_processor.mdx
@@ -0,0 +1,273 @@
+# Implement your own Robot Processor
+
+In this tutorial, you'll learn how to implement your own Robot Processor.
+It begins by exploring the need for a custom processor, then uses the `NormalizerProcessorStep` as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot.
+
+## Why would you need a custom processor?
+
+In most cases, when reading raw data from sensors or when models output actions, you need to process this data to make it compatible with your target system. For example, a common need is normalizing data ranges to make them suitable for neural networks.
+
+LeRobot's `NormalizerProcessorStep` handles this crucial task:
+
+```python
+# Input: raw joint positions in [0, 180] degrees
+raw_action = torch.tensor([90.0, 45.0, 135.0])
+
+# After processing: normalized to [-1, 1] range for model training
+normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=dataset_stats)
+normalized_result = normalizer(transition)
+# ...
+```
+
+Other common processing needs include:
+
+- **Device placement**: Moving tensors between CPU/GPU and converting data types
+- **Format conversion**: Transforming between different data structures
+- **Batching**: Adding/removing batch dimensions for model compatibility
+- **Safety constraints**: Applying limits to robot commands
+
+```python
+# Example pipeline combining multiple processors
+pipeline = PolicyProcessorPipeline([
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ NormalizerProcessorStep(features=features, stats=stats),
+ DeviceProcessorStep(device="cuda"),
+ # ...
+])
+```
+
+LeRobot provides a pipeline mechanism to implement sequences of processing steps for both input data and output actions, making it easy to compose these transformations in the right order for optimal performance.
+
+## How to implement your own processor?
+
+We'll use the `NormalizerProcessorStep` as our main example because it demonstrates essential processor patterns including state management, configuration serialization, and tensor handling that you'll commonly need.
+
+Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods:
+
+- `__call__`: implements the processing step for the input transition.
+- `get_config`: gets the configuration of the processor step.
+- `state_dict`: gets the state of the processor step.
+- `load_state_dict`: loads the state of the processor step.
+- `reset`: resets the state of the processor step.
+- `feature_contract`: displays the modification to the feature space during the processor step.
+
+### Implement the `__call__` method
+
+The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessorStep` works:
+
+```python
+@dataclass
+@ProcessorStepRegistry.register("normalizer_processor")
+class NormalizerProcessorStep(ProcessorStep):
+ """Normalize observations/actions using dataset statistics."""
+
+ features: dict[str, PolicyFeature]
+ norm_map: dict[FeatureType, NormalizationMode]
+ stats: dict[str, dict[str, Any]] | None = None
+ eps: float = 1e-8
+ _tensor_stats: dict = field(default_factory=dict, init=False, repr=False)
+
+ def __post_init__(self):
+ """Convert stats to tensors for efficient computation."""
+ self.stats = self.stats or {}
+ self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=torch.float32)
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ new_transition = transition.copy()
+ # Normalize observations
+ # ...
+ # Normalize action
+ # ...
+ return new_transition
+
+```
+
+See the full implementation in `src/lerobot/processor/normalize_processor.py` for complete details.
+
+**Key principles:**
+
+- **Always use `transition.copy()`** to avoid side effects
+- **Handle both observations and actions** consistently
+- **Separate config from state**: `get_config()` returns JSON-serializable params, `state_dict()` returns tensors
+- **Convert stats to tensors** in `__post_init__()` for efficient computation
+
+### Configuration and State Management
+
+Processors support serialization through three methods that separate configuration from tensor state. The `NormalizerProcessorStep` demonstrates this perfectly - it carries dataset statistics (tensors) in its state, and hyperparameters in its config:
+
+```python
+# Continuing the NormalizerProcessorStep example...
+
+def get_config(self) -> dict[str, Any]:
+ """JSON-serializable configuration (no tensors)."""
+ return {
+ "eps": self.eps,
+ "features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()},
+ "norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()},
+ # ...
+ }
+
+def state_dict(self) -> dict[str, torch.Tensor]:
+ """Tensor state only (e.g., dataset statistics)."""
+ flat: dict[str, torch.Tensor] = {}
+ for key, sub in self._tensor_stats.items():
+ for stat_name, tensor in sub.items():
+ flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
+ return flat
+
+def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
+ """Restore tensor state at runtime."""
+ self._tensor_stats.clear()
+ for flat_key, tensor in state.items():
+ key, stat_name = flat_key.rsplit(".", 1)
+ # Load to processor's configured device
+ self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
+ dtype=torch.float32, device=self.device
+ )
+ # ...
+```
+
+**Usage:**
+
+```python
+# Save (e.g., inside a policy)
+config = normalizer.get_config()
+tensors = normalizer.state_dict()
+
+# Restore (e.g., loading a pretrained policy)
+new_normalizer = NormalizerProcessorStep(**config)
+new_normalizer.load_state_dict(tensors)
+# Now new_normalizer has the same stats and configuration
+```
+
+### Transform features
+
+The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging.
+
+For `NormalizerProcessorStep`, features are typically preserved unchanged since normalization doesn't alter keys or shapes:
+
+```python
+def transform_features(self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Normalization preserves all feature definitions."""
+ return features # No changes to feature structure
+ # ...
+```
+
+When your processor renames or reshapes data, implement this method to reflect the mapping for downstream components. For example, a simple rename processor:
+
+```python
+def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+ # Simple renaming
+ if "pixels" in features:
+ features["observation.image"] = features.pop("pixels")
+
+ # Pattern-based renaming
+ for key in list(features.keys()):
+ if key.startswith("env_state."):
+ suffix = key[len("env_state."):]
+ features[f"observation.{suffix}"] = features.pop(key)
+ # ...
+
+ return features
+```
+
+**Key principles:**
+
+- Use `features.pop(old_key)` to remove and get the old feature
+- Use `features[new_key] = old_feature` to add the renamed feature
+- Always return the modified features dictionary
+- Document transformations clearly in the docstring
+
+### Using overrides
+
+You can override step parameters at load-time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `DataProcessorPipeline.from_pretrained(...)`.
+
+**Foundational model adaptation**: This is particularly useful when working with foundational pretrained policies where you rarely have access to the original training statistics. You can inject your own dataset statistics to adapt the normalizer to your specific robot or environment data.
+
+Example: during policy evaluation on the robot, override the device and rename map.
+Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset.
+
+Direct usage with `from_pretrained`:
+
+```python
+from lerobot.processor import RobotProcessorPipeline
+
+# Load a foundational policy trained on diverse robot data
+# but adapt normalization to your specific robot/environment
+new_stats = LeRobotDataset(repo_id="username/my-dataset").meta.stats
+processor = RobotProcessorPipeline.from_pretrained(
+ "huggingface/foundational-robot-policy", # Pretrained foundation model
+ overrides={
+ "normalizer_processor": {"stats": new_stats}, # Inject your robot's statistics
+ "device_processor": {"device": "cuda:0"}, # registry name for registered steps
+ "rename_processor": {"rename_map": robot_key_map}, # Map your robot's observation keys
+ # ...
+ },
+)
+```
+
+## Best Practices
+
+Based on analysis of all LeRobot processor implementations, here are the key patterns and practices:
+
+### 1. **Safe Data Handling**
+
+Always create copies of input data to avoid unintended side effects. Use `transition.copy()` and `observation.copy()` rather than modifying data in-place. This prevents your processor from accidentally affecting other components in the pipeline.
+
+Check for required data before processing and handle missing data gracefully. If your processor expects certain keys (like `"pixels"` for image processing), validate their presence first. For optional data, use safe access patterns like `transition.get()` and handle `None` values appropriately.
+
+When data validation fails, provide clear, actionable error messages that help users understand what went wrong and how to fix it.
+
+### 2. **Choose Appropriate Base Classes**
+
+LeRobot provides specialized base classes that reduce boilerplate code and ensure consistency. Use `ObservationProcessorStep` when you only need to modify observations, `ActionProcessorStep` for action-only processing, and `RobotActionProcessorStep` specifically for dictionary-based robot actions.
+
+Only inherit directly from `ProcessorStep` when you need full control over the entire transition or when processing multiple transition components simultaneously. The specialized base classes handle the transition management for you and provide type safety.
+
+### 3. **Registration and Naming**
+
+Register your processors with descriptive, namespaced names using `@ProcessorStepRegistry.register()`. Use organization prefixes like `"robotics_lab/safety_clipper"` or `"acme_corp/vision_enhancer"` to avoid naming conflicts. Avoid generic names like `"processor"` or `"step"` that could clash with other implementations.
+
+Good registration makes your processors discoverable and enables clean serialization/deserialization when saving and loading pipelines.
+
+### 4. **State Management Patterns**
+
+Distinguish between configuration parameters (JSON-serializable values) and internal state (tensors, buffers). Use dataclass fields with `init=False, repr=False` for internal state that shouldn't appear in the constructor or string representation.
+
+Implement the `reset()` method to clear internal state between episodes. This is crucial for stateful processors that accumulate data over time, like moving averages or temporal filters.
+
+Remember that `get_config()` should only return JSON-serializable configuration, while `state_dict()` handles tensor state separately.
+
+### 5. **Input Validation and Error Handling**
+
+Validate input types and shapes before processing. Check tensor properties like `dtype` and dimensions to ensure compatibility with your algorithms. For robot actions, verify that required pose components or joint values are present and within expected ranges.
+
+Use early returns for edge cases where no processing is needed. Provide clear, descriptive error messages that include the expected vs. actual data types or shapes. This makes debugging much easier for users.
+
+### 6. **Device and Dtype Awareness**
+
+Design your processors to automatically adapt to the device and dtype of input tensors. Internal tensors (like normalization statistics) should match the input tensor's device and dtype to ensure compatibility with multi-GPU training, mixed precision, and distributed setups.
+
+Implement a `to()` method that moves your processor's internal state to the specified device. Check device/dtype compatibility at runtime and automatically migrate internal state when needed. This pattern enables seamless operation across different hardware configurations without manual intervention.
+
+## Conclusion
+
+You now have all the tools to implement custom processors in LeRobot! The key steps are:
+
+1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `transform_features`)
+2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability
+3. **Integrate it** into a `DataProcessorPipeline` with other processing steps
+4. **Use base classes** like `ObservationProcessorStep` when possible to reduce boilerplate
+5. **Implement device/dtype awareness** to support multi-GPU and mixed precision setups
+
+The processor system is designed to be modular and composable, allowing you to build complex data processing pipelines from simple, focused components. Whether you're preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors give you the flexibility to handle any data transformation your robotics application requires.
+
+Key principles for robust processors:
+
+- **Device/dtype adaptation**: Internal tensors should match input tensors
+- **Clear error messages**: Help users understand what went wrong
+- **Base class usage**: Leverage specialized base classes to reduce boilerplate
+- **Feature contracts**: Declare data structure changes with `transform_features()`
+
+Start simple, test thoroughly, and ensure your processors work seamlessly across different hardware configurations!
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index b8ff56ea7..a2f919e7d 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -1,6 +1,10 @@
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 51474d8f7..93354c2ee 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -1,49 +1,88 @@
# Installation
-## Install LeRobot
-
-Currently only available from source.
-
-Download our source code:
-```bash
-git clone https://github.com/huggingface/lerobot.git
-cd lerobot
-```
+## Environment Setup
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
+
```bash
conda create -y -n lerobot python=3.10
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
+
```bash
conda activate lerobot
```
When using `miniconda`, install `ffmpeg` in your environment:
+
```bash
conda install ffmpeg -c conda-forge
```
> [!TIP]
> This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
-> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
-> ```bash
-> conda install ffmpeg=7.1.1 -c conda-forge
-> ```
-> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
+>
+> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
+>
+> ```bash
+> conda install ffmpeg=7.1.1 -c conda-forge
+> ```
+>
+> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
+
+## Install LeRobot 🤗
+
+### From Source
+
+First, clone the repository and navigate into the directory:
+
+```bash
+git clone https://github.com/huggingface/lerobot.git
+cd lerobot
+```
+
+Then, install the library in editable mode. This is useful if you plan to contribute to the code.
-Install 🤗 LeRobot:
```bash
pip install -e .
```
+### Installation from PyPI
+
+**Core Library:**
+Install the base package with:
+
+```bash
+pip install lerobot
+```
+
+_This installs only the default dependencies._
+
+**Extra Features:**
+To install additional functionality, use one of the following:
+
+```bash
+pip install 'lerobot[all]' # All available features
+pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
+pip install 'lerobot[feetech]' # Feetech motor support
+```
+
+_Replace `[...]` with your desired features._
+
+**Available Tags:**
+For a full list of optional dependencies, see:
+https://pypi.org/project/lerobot/
+
### Troubleshooting
+
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
To install these for linux run:
+
```bash
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
```
+
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
## Optional dependencies
@@ -51,20 +90,26 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
### Simulations
+
Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
Example:
+
```bash
pip install -e ".[aloha]" # or "[pusht]" for example
```
### Motor Control
+
For Koch v1.1 install the Dynamixel SDK, for SO100/SO101/Moss install the Feetech SDK.
+
```bash
pip install -e ".[feetech]" # or "[dynamixel]" for example
```
### Experiment Tracking
+
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
+
```bash
wandb login
```
diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx
index 18d73d3cd..7e7fe0bff 100644
--- a/docs/source/integrate_hardware.mdx
+++ b/docs/source/integrate_hardware.mdx
@@ -2,37 +2,34 @@
This tutorial will explain how to integrate your own robot design into the LeRobot ecosystem and have it access all of our tools (data collection, control pipelines, policy training and inference).
-To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it.
+To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it.
## Prerequisites
- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP)
- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation.
-- LeRobot installed in your environment. Follow our [Installation Guide](./installation).
+- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx).
## Choose your motors
If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus interfaces:
-- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/feetech.py) – for controlling Feetech servos
-- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos
+- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/feetech.py) – for controlling Feetech servos
+- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos
-Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/motors_bus.py) abstract class to learn about its API.
-For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/so101_follower/so101_follower.py)
+Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API.
+For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py)
Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial):
+
- Find an existing SDK in Python (or use bindings to C/C++)
- Or implement a basic communication wrapper (e.g., via pyserial, socket, or CANopen)
You're not alone—many community contributions use custom boards or firmware!
-For Feetech and Dynamixel, we currently support these servos:
- - Feetech:
- - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl`
- - SCS series (protocol 1): `scs0009`
- - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150`
+For Feetech and Dynamixel, we currently support these servos: - Feetech: - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` - SCS series (protocol 1): `scs0009` - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150`
-If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do.
+If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do.
In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for the examples. Replace it and adapt to your motors if necessary.
@@ -41,6 +38,8 @@ In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for
You’ll first need to specify the config class and a string identifier (`name`) for your robot. If your robot has special needs that you'd like to be able to change easily, it should go here (e.g. port/address, baudrate).
Here, we'll add the port name and one camera by default for our robot:
+
+
```python
from dataclasses import dataclass, field
@@ -64,13 +63,15 @@ class MyCoolRobotConfig(RobotConfig):
}
)
```
+
-Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera.
+[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera.
Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools.
Here we'll create a simple 5-DoF robot with one camera. It could be a simple arm but notice that the `Robot` abstract class does not assume anything on your robot's form factor. You can let you imagination run wild when designing new robots!
+
```python
from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
@@ -96,10 +97,11 @@ class MyCoolRobot(Robot):
)
self.cameras = make_cameras_from_configs(config.cameras)
```
+
## Step 2: Define Observation and Action Features
-These two properties define the *interface contract* between your robot and tools that consume it (such as data collection or learning pipelines).
+These two properties define the _interface contract_ between your robot and tools that consume it (such as data collection or learning pipelines).
> [!WARNING]
> Note that these properties must be callable even if the robot is not yet connected, so avoid relying on runtime hardware state to define them.
@@ -109,6 +111,8 @@ These two properties define the *interface contract* between your robot and tool
This property should return a dictionary describing the structure of sensor outputs from your robot. The keys match what `get_observation()` returns, and the values describe either the shape (for arrays/images) or the type (for simple values).
Example for our 5-DoF arm with one camera:
+
+
```python
@property
def _motors_ft(self) -> dict[str, type]:
@@ -130,6 +134,8 @@ def _cameras_ft(self) -> dict[str, tuple]:
def observation_features(self) -> dict:
return {**self._motors_ft, **self._cameras_ft}
```
+
+
In this case, observations consist of a simple dict storing each motor's position and a camera image.
### `action_features`
@@ -137,10 +143,13 @@ In this case, observations consist of a simple dict storing each motor's positio
This property describes the commands your robot expects via `send_action()`. Again, keys must match the expected input format, and values define the shape/type of each command.
Here, we simply use the same joints proprioceptive features (`self._motors_ft`) as with `observation_features`: the action sent will simply the goal position for each motor.
+
+
```python
def action_features(self) -> dict:
return self._motors_ft
```
+
## Step 3: Handle Connection and Disconnection
@@ -150,16 +159,19 @@ These methods should handle opening and closing communication with your hardware
This property should simply reflect that communication with the robot's hardware is established. When this property is `True`, it should be possible to read and write to the hardware using `get_observation()` and `send_action()`.
+
```python
@property
def is_connected(self) -> bool:
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
```
+
### `connect()`
This method should establish communication with the hardware. Moreover, if your robot needs calibration and is not calibrated, it should start a calibration procedure by default. If your robot needs some specific configuration, this should also be called here.
+
```python
def connect(self, calibrate: bool = True) -> None:
self.bus.connect()
@@ -171,25 +183,31 @@ def connect(self, calibrate: bool = True) -> None:
self.configure()
```
+
### `disconnect()`
This method should gracefully terminate communication with the hardware: free any related resources (threads or processes), close ports, etc.
Here, we already handle this in our `MotorsBus` and `Camera` classes so we just need to call their own `disconnect()` methods:
+
+
```python
def disconnect(self) -> None:
self.bus.disconnect()
for cam in self.cameras.values():
cam.disconnect()
```
+
## Step 4: Support Calibration and Configuration
LeRobot supports saving and loading calibration data automatically. This is useful for joint offsets, zero positions, or sensor alignment.
> Note that depending on your hardware, this may not apply. If that's the case, you can simply leave these methods as no-ops:
-> ```python
+
+
+```python
> @property
> def is_calibrated(self) -> bool:
> return True
@@ -202,7 +220,8 @@ LeRobot supports saving and loading calibration data automatically. This is usef
This should reflect whether your robot has the required calibration loaded.
-```python
+```
+python
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
@@ -216,6 +235,8 @@ The goal of the calibration is twofold:
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
+
+
```python
def calibrate(self) -> None:
self.bus.disable_torque()
@@ -245,11 +266,13 @@ def calibrate(self) -> None:
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
```
+
### `configure()`
Use this to set up any configuration for your hardware (servos control modes, controller gains, etc.). This should usually be run at connection time and be idempotent.
+
```python
def configure(self) -> None:
with self.bus.torque_disabled():
@@ -260,6 +283,7 @@ def configure(self) -> None:
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
```
+
## Step 5: Implement Sensors Reading and Action Sending
@@ -269,6 +293,7 @@ These are the most important runtime functions: the core I/O loop.
Returns a dictionary of sensor values from the robot. These typically include motor states, camera frames, various sensors, etc. In the LeRobot framework, these observations are what will be fed to a policy in order to predict the actions to take. The dictionary keys and structure must match `observation_features`.
+
```python
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
@@ -284,6 +309,7 @@ def get_observation(self) -> dict[str, Any]:
return obs_dict
```
+
### `send_action()`
@@ -291,6 +317,7 @@ Takes a dictionary that matches `action_features`, and sends it to your hardware
For simplicity, we won't be adding any modification of the actions in our example here.
+
```python
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items()}
@@ -300,13 +327,142 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
return action
```
+
## Adding a Teleoperator
-For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor.
+For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor.
The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods.
+## Using Your Own `LeRobot` Devices 🔌
+
+You can easily extend `lerobot` with your own custom hardware—be it a camera, robot, or teleoperation device—by creating a separate, installable Python package. If you follow a few simple conventions, the `lerobot` command-line tools (like `lerobot-teleop` and `lerobot-record`) will **automatically discover and integrate your creations** without requiring any changes to the `lerobot` source code.
+
+This guide outlines the conventions your plugin must follow.
+
+### The 4 Core Conventions
+
+To ensure your custom device is discoverable, you must adhere to the following four rules.
+
+#### 1\. Create an Installable Package with a Specific Prefix
+
+Your project must be a standard, installable Python package. Crucially, the name of your package (as defined in `pyproject.toml` or `setup.py`) must begin with one of these prefixes:
+
+- `lerobot_robot_` for a robot.
+- `lerobot_camera_` for a camera.
+- `lerobot_teleoperator_` for a teleoperation device.
+
+This prefix system is how `lerobot` automatically finds your plugin in the Python environment.
+
+#### 2\. Follow the `SomethingConfig`/`Something` Naming Pattern
+
+Your device's implementation class must be named after its configuration class, simply by removing the `Config` suffix.
+
+- **Config Class:** `MyAwesomeTeleopConfig`
+- **Device Class:** `MyAwesomeTeleop`
+
+#### 3\. Place Your Files in a Predictable Structure
+
+The device class (`MyAwesomeTeleop`) must be located in a predictable module relative to its configuration class (`MyAwesomeTeleopConfig`). `lerobot` will automatically search in these locations:
+
+- In the **same module** as the config class.
+- In a **submodule named after the device** (e.g., `my_awesome_teleop.py`).
+
+The recommended and simplest structure is to place them in separate, clearly named files within the same directory.
+
+#### 4\. Expose Classes in `__init__.py`
+
+Your package's `__init__.py` file should import and expose both the configuration and the device classes, making them easily accessible.
+
+### Putting It All Together: A Complete Example
+
+Let's create a new teleoperator called `my_awesome_teleop`.
+
+#### Directory Structure
+
+Here is what the project folder should look like. The package name, `lerobot_teleoperator_my_awesome_teleop`, follows **Convention \#1**.
+
+```
+lerobot_teleoperator_my_awesome_teleop/
+├── pyproject.toml # (or setup.py) lists lerobot as a dependency
+└── lerobot_teleoperator_my_awesome_teleop/
+ ├── __init__.py
+ ├── config_my_awesome_teleop.py
+ └── my_awesome_teleop.py
+```
+
+#### File Contents
+
+- **`config_my_awesome_teleop.py`**: Defines the configuration class. Note the `Config` suffix (**Convention \#2**).
+
+ ```python
+ from dataclasses import dataclass
+
+ from lerobot.teleoperators.config import TeleoperatorConfig
+
+ @TeleoperatorConfig.register_subclass("my_awesome_teleop")
+ @dataclass
+ class MyAwesomeTeleopConfig(TeleoperatorConfig):
+ # Your configuration fields go here
+ port: str = "192.168.1.1"
+ ```
+
+- **`my_awesome_teleop.py`**: Implements the device. The class name `MyAwesomeTeleop` matches its config class name (**Convention \#2**). This file structure adheres to **Convention \#3**.
+
+ ```python
+ from lerobot.teleoperators.teleoperator import Teleoperator
+
+ from .config_my_awesome_teleop import MyAwesomeTeleopConfig
+
+ class MyAwesomeTeleop(Teleoperator):
+ config_class = MyAwesomeTeleopConfig
+ name = "my_awesome_teleop"
+
+ def __init__(self, config: MyAwesomeTeleopConfig):
+ super().__init__(config)
+ self.config = config
+
+ # Your device logic (e.g., connect) goes here
+ ```
+
+- **`__init__.py`**: Exposes the key classes (**Convention \#4**).
+
+ ```python
+ from .config_my_awesome_teleop import MyAwesomeTeleopConfig
+ from .my_awesome_teleop import MyAwesomeTeleop
+ ```
+
+### Installation and Usage
+
+1. **Install your new plugin in your Python environment.** You can install your local plugin package using `pip`'s editable mode or from PyPi.
+
+ ```bash
+ # Locally
+ # Navigate to your plugin's root directory and install it
+ cd lerobot_teleoperator_my_awesome_teleop
+ pip install -e .
+
+ # From PyPi
+ pip install lerobot_teleoperator_my_awesome_teleop
+ ```
+
+2. **Use it directly from the command line.** Now, you can use your custom device by referencing its type.
+
+ ```bash
+ lerobot-teleoperate --teleop.type=my_awesome_teleop \
+ # other arguments
+ ```
+
+And that's it\! Your custom device is now fully integrated.
+
+### Looking for an example ?
+
+Check out these two packages from the community:
+
+- https://github.com/SpesRobotics/lerobot-robot-xarm
+- https://github.com/SpesRobotics/lerobot-teleoperator-teleop
+
## Wrapping Up
Once your robot class is complete, you can leverage the LeRobot ecosystem:
diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx
new file mode 100644
index 000000000..308edbb3b
--- /dev/null
+++ b/docs/source/introduction_processors.mdx
@@ -0,0 +1,314 @@
+# Introduction to Processors
+
+In robotics, there's a fundamental mismatch between the data that robots and humans produce and what machine learning models expect.
+Robots output raw sensor data like camera images and joint positions that need normalization, batching, and device placement before models can process them.
+Language instructions from humans must be tokenized into numerical representations, and different robots use different coordinate systems that need standardization.
+
+The challenge extends to model outputs as well.
+Models might output end-effector positions while robots need joint-space commands, or teleoperators produce relative movements while robots expect absolute commands.
+Model predictions are often normalized and need conversion back to real-world scales.
+
+Cross-domain translation adds another layer of complexity.
+Training data from one robot setup needs adaptation for deployment on different hardware, models trained with specific camera configurations must work with new arrangements, and datasets with different naming conventions need harmonization.
+
+**That's where processors come in.** They serve as universal translators that bridge these gaps, ensuring seamless data flow from sensors to models to actuators.
+Processors handle all the preprocessing and postprocessing steps needed to convert raw environment data into model-ready inputs and vice versa.
+
+This means that your favorite policy can be used like this:
+
+```python
+import torch
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.your_policy import YourPolicy
+from lerobot.processor.pipeline import RobotProcessorPipeline, PolicyProcessorPipeline
+dataset = LeRobotDataset("hf_user/dataset", episodes=[0])
+sample = dataset[10]
+
+model = YourPolicy.from_pretrained(
+ "hf_user/model",
+)
+model.eval()
+model.to("cuda")
+preprocessor, postprocessor = make_pre_post_processors(model.config, pretrained_path="hf_user/model", dataset_stats=dataset.meta.stats)
+
+preprocessed_sample = preprocessor(sample)
+action = model.select_action(preprocessed_sample)
+postprocessed_action = postprocessor(action)
+```
+
+## What are Processors?
+
+In robotics, data comes in many forms: images from cameras, joint positions from sensors, text instructions from users, and more. Each type of data requires specific transformations before a model can use it effectively. Models need this data to be:
+
+- **Normalized**: Scaled to appropriate ranges for neural network processing
+- **Batched**: Organized with proper dimensions for batch processing
+- **Tokenized**: Text converted to numerical representations
+- **Device-placed**: Moved to the right hardware (CPU/GPU)
+- **Type-converted**: Cast to appropriate data types
+
+Processors handle these transformations through composable, reusable steps that can be chained together into pipelines. Think of them as a modular assembly line where each station performs a specific transformation on your data.
+
+## Core Concepts
+
+### EnvTransition: The Universal Data Container
+
+The `EnvTransition` is the fundamental data structure that flows through all processors.
+It's a typed dictionary that represents a complete robot-environment interaction:
+
+- **OBSERVATION**: All sensor data (images, states, proprioception)
+- **ACTION**: The action to execute or that was executed
+- **REWARD**: Reinforcement learning signal
+- **DONE/TRUNCATED**: Episode boundary indicators
+- **INFO**: Arbitrary metadata
+- **COMPLEMENTARY_DATA**: Task descriptions, indices, padding flags, inter-step data
+
+### ProcessorStep: The Building Block
+
+A `ProcessorStep` is a single transformation unit that processes transitions. It's an abstract base class with two required methods:
+
+```python
+from lerobot.processor import ProcessorStep, EnvTransition
+
+class MyProcessorStep(ProcessorStep):
+ """Example processor step - inherit and implement abstract methods."""
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Transform the transition - REQUIRED abstract method."""
+ # Your processing logic here
+ return transition
+
+ def transform_features(self, features):
+ """Declare how this step transforms feature shapes/types - REQUIRED abstract method."""
+ return features # Most processors return features unchanged
+```
+
+`__call__` is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`.
+
+`transform_features` is used to declare how this step transforms feature shapes/types.
+
+### DataProcessorPipeline: The Generic Orchestrator
+
+The `DataProcessorPipeline[TInput, TOutput]` chains multiple `ProcessorStep` instances together:
+
+```python
+from lerobot.processor import RobotProcessorPipeline, PolicyProcessorPipeline
+
+# For robot hardware (unbatched data)
+robot_processor = RobotProcessorPipeline[RobotAction, RobotAction](
+ steps=[step1, step2, step3],
+ name="robot_pipeline"
+)
+
+# For model training/inference (batched data)
+policy_processor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=[step1, step2, step3],
+ name="policy_pipeline"
+)
+```
+
+## RobotProcessorPipeline vs PolicyProcessorPipeline
+
+The key distinction is in the data structures they handle:
+
+| Aspect | RobotProcessorPipeline | PolicyProcessorPipeline |
+| --------------- | -------------------------------------------- | ---------------------------------------- |
+| **Input** | `dict[str, Any]` - Individual robot values | `dict[str, Any]` - Batched tensors |
+| **Output** | `dict[str, Any]` - Individual robot commands | `torch.Tensor` - Policy predictions |
+| **Use Case** | Real-time robot control | Model training/inference |
+| **Data Format** | Unbatched, heterogeneous | Batched, homogeneous |
+| **Examples** | `{"joint_1": 0.5}` | `{"observation.state": tensor([[0.5]])}` |
+
+**Use `RobotProcessorPipeline`** for robot hardware interfaces:
+
+```python
+# Robot data structures: dict[str, Any] for observations and actions
+robot_obs: dict[str, Any] = {
+ "joint_1": 0.5, # Individual joint values
+ "joint_2": -0.3,
+ "camera_0": image_array # Raw camera data
+}
+
+robot_action: dict[str, Any] = {
+ "joint_1": 0.2, # Target joint positions
+ "joint_2": 0.1,
+ "gripper": 0.8
+}
+```
+
+**Use `PolicyProcessorPipeline`** for model training and batch processing:
+
+```python
+# Policy data structures: batch dicts and tensors
+policy_batch: dict[str, Any] = {
+ "observation.state": torch.tensor([[0.5, -0.3]]), # Batched states
+ "observation.images.camera0": torch.tensor(...), # Batched images
+ "action": torch.tensor([[0.2, 0.1, 0.8]]) # Batched actions
+}
+
+policy_action: torch.Tensor = torch.tensor([[0.2, 0.1, 0.8]]) # Model output tensor
+```
+
+## Converter Functions
+
+LeRobot provides converter functions to bridge different data formats in `lerobot.processor.converters`. These functions handle the crucial translations between robot hardware data structures, policy model formats, and the internal `EnvTransition` representation that flows through processor pipelines.
+
+| Category | Function | Description |
+| ------------------------------ | ----------------------------- | ------------------------------- |
+| **Robot Hardware Converters** | `robot_action_to_transition` | Robot dict → EnvTransition |
+| | `observation_to_transition` | Robot obs → EnvTransition |
+| | `transition_to_robot_action` | EnvTransition → Robot dict |
+| **Policy/Training Converters** | `batch_to_transition` | Batch dict → EnvTransition |
+| | `transition_to_batch` | EnvTransition → Batch dict |
+| | `policy_action_to_transition` | Policy tensor → EnvTransition |
+| | `transition_to_policy_action` | EnvTransition → Policy tensor |
+| **Utilities** | `create_transition` | Build transitions with defaults |
+| | `identity_transition` | Pass-through converter |
+
+The key insight is that **robot hardware converters** work with individual values and dictionaries, while **policy/training converters** work with batched tensors and model outputs. The converter functions automatically handle the structural differences, so your processor steps can focus on the core transformations without worrying about data format compatibility.
+
+## Processor Examples
+
+The following examples demonstrate real-world processor configurations for policy training and inference.
+
+Here is an example processor for policy training and inference:
+
+```python
+# Training data preprocessing (optimized order for GPU performance)
+training_preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=[
+ RenameObservationsProcessorStep(rename_map={}), # Standardize keys
+ AddBatchDimensionProcessorStep(), # Add batch dims
+ TokenizerProcessorStep(tokenizer_name="...", ...), # Tokenize language
+ DeviceProcessorStep(device="cuda"), # Move to GPU first
+ NormalizerProcessorStep(features=..., stats=...), # Normalize on GPU
+ ]
+)
+
+# Model output postprocessing
+training_postprocessor = PolicyProcessorPipeline[torch.Tensor, torch.Tensor](
+ steps=[
+ DeviceProcessorStep(device="cpu"), # Move to CPU
+ UnnormalizerProcessorStep(features=..., stats=...), # Denormalize
+ ]
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+)
+```
+
+### An interaction between a robot and a policy with processors
+
+The most common real-world scenario combines both pipeline types robot hardware generates observations that need policy processing, and policy outputs need robot-compatible postprocessing:
+
+```python
+# Real deployment: Robot sensors → Model → Robot commands
+with torch.no_grad():
+ while not done:
+ raw_obs = robot.get_observation() # dict[str, Any]
+
+ # Add your robot observation to policy observation processor
+
+ policy_input = policy_preprocessor(raw_obs) # Batched dict
+
+ policy_output = policy.select_action(policy_input) # Policy tensor
+
+ policy_action = policy_postprocessor(policy_output)
+
+ # Add your robot action to policy action processor
+
+ robot.send_action(policy_action)
+```
+
+## Feature Contracts: Shape and Type Transformation
+
+Processors don't just transform data - they can also **change the data structure itself**. The `transform_features()` method declares these changes, which is crucial for dataset recording and policy creation.
+
+### Why Feature Contracts Matter
+
+When building datasets or policies, LeRobot needs to know:
+
+- **What data fields will exist** after processing
+- **What shapes and types** each field will have
+- **How to configure models** for the expected data structure
+
+```python
+# Example: A processor that adds velocity to observations
+class VelocityProcessor(ObservationProcessorStep):
+ def observation(self, obs):
+ new_obs = obs.copy()
+ if "observation.state" in obs:
+ # concatenate computed velocity field to the state
+ new_obs["observation.state"] = self._compute_velocity(obs["observation.state"])
+ return new_obs
+
+ def transform_features(self, features):
+ """Declare the new velocity field we're adding."""
+ state_feature = features[PipelineFeatureType.OBSERVATION].get("observation.state")
+ if state_feature:
+ double_shape = (state_feature.shape[0] * 2,) if state_feature.shape else (2,)
+ features[PipelineFeatureType.OBSERVATION]["observation.state"] = PolicyFeature(
+ type=FeatureType.STATE, shape=double_shape
+ )
+ return features
+```
+
+### Feature Specification Functions
+
+`create_initial_features()` and `aggregate_pipeline_dataset_features()` solve a critical dataset creation problem: determining the exact final data structure before any data is processed.
+Since processor pipelines can add new features (like velocity fields), change tensor shapes (like cropping images), or rename keys, datasets need to know the complete output specification upfront to allocate proper storage and define schemas.
+These functions work together by starting with robot hardware specifications (`create_initial_features()`) then simulating the entire pipeline transformation (`aggregate_pipeline_dataset_features()`) to compute the final feature dictionary that gets passed to `LeRobotDataset.create()`, ensuring perfect alignment between what processors output and what datasets expect to store.
+
+```python
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
+
+# Start with robot's raw features
+initial_features = create_initial_features(
+ observation=robot.observation_features, # {"joint_1.pos": float, "camera_0": (480,640,3)}
+ action=robot.action_features # {"joint_1.pos": float, "gripper.pos": float}
+)
+
+# Apply processor pipeline to compute final features
+final_features = aggregate_pipeline_dataset_features(
+ pipeline=my_processor_pipeline,
+ initial_features=initial_features,
+ use_videos=True
+)
+
+# Use for dataset creation
+dataset = LeRobotDataset.create(
+ repo_id="my_dataset",
+ features=final_features, # Knows exactly what data to expect
+ ...
+)
+```
+
+## Common Processor Steps
+
+LeRobot provides many registered processor steps. Here are the most commonly used core processors:
+
+### Essential Processors
+
+- **`normalizer_processor`**: Normalize observations/actions using dataset statistics (mean/std or min/max)
+- **`device_processor`**: Move tensors to CPU/GPU with optional dtype conversion
+- **`to_batch_processor`**: Add batch dimensions to transitions for model compatibility
+- **`rename_observations_processor`**: Rename observation keys using mapping dictionaries
+- **`tokenizer_processor`**: Tokenize natural language task descriptions into tokens and attention masks
+
+### Next Steps
+
+- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps
+- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines
+- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns
+
+## Summary
+
+Processors solve the data translation problem in robotics by providing:
+
+- **Modular transformations**: Composable, reusable processing steps
+- **Type safety**: Generic pipelines with compile-time checking
+- **Performance optimization**: GPU-accelerated operations
+- **Robot/Policy distinction**: Separate pipelines for different data structures
+- **Comprehensive ecosystem**: 30+ registered processors for common tasks
+
+The key insight: `RobotProcessorPipeline` handles unbatched robot hardware data, while `PolicyProcessorPipeline` handles batched model data. Choose the right tool for your data structure!
diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx
deleted file mode 120000
index 5383518b3..000000000
--- a/docs/source/koch.mdx
+++ /dev/null
@@ -1 +0,0 @@
-../../src/lerobot/robots/koch_follower/koch.mdx
\ No newline at end of file
diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx
new file mode 100644
index 000000000..813b9bd67
--- /dev/null
+++ b/docs/source/koch.mdx
@@ -0,0 +1,283 @@
+# Koch v1.1
+
+In the steps below, we explain how to assemble the Koch v1.1 robot.
+
+## Order and assemble the parts
+
+Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below.
+
+For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk).
+
+> [!WARNING]
+> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video:
+>
+> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors).
+> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors).
+
+## Install LeRobot 🤗
+
+To install LeRobot follow, our [Installation Guide](./installation)
+
+In addition to these instructions, you need to install the Dynamixel SDK:
+
+```bash
+pip install -e ".[dynamixel]"
+```
+
+## Configure the motors
+
+### 1. Find the USB ports associated with each arm
+
+To find the port for each bus servo adapter, run this script:
+
+```bash
+lerobot-find-port
+```
+
+
+
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
+Remove the USB cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
+
+
+
+
+On Linux, you might need to give access to the USB ports by running:
+
+```bash
+sudo chmod 666 /dev/ttyACM0
+sudo chmod 666 /dev/ttyACM1
+```
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/ttyACM0', '/dev/ttyACM1']
+Remove the usb cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/ttyACM1
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
+
+
+
+
+### 2. Set the motors ids and baudrates
+
+Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
+
+To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
+
+If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match.
+
+#### Follower
+
+Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
+
+For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm.
+
+
+
+
+```bash
+lerobot-setup-motors \
+ --robot.type=koch_follower \
+ --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
+```
+
+
+
+
+
+```python
+from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
+
+config = KochFollowerConfig(
+ port="/dev/tty.usbmodem575E0031751",
+ id="my_awesome_follower_arm",
+)
+follower = KochFollower(config)
+follower.setup_motors()
+```
+
+
+
+
+
+You should see the following instruction.
+
+```
+Connect the controller board to the 'gripper' motor only and press enter.
+```
+
+As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
+
+
+Troubleshooting
+
+If you get an error at that point, check your cables and make sure they are plugged in properly:
+
+
+
Power supply
+
USB cable between your computer and the controller board
+
The 3-pin cable from the controller board to the motor
+
+
+If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
+
+
+
+You should then see the following message:
+
+```
+'gripper' motor id set to 6
+```
+
+Followed by the next instruction:
+
+```
+Connect the controller board to the 'wrist_roll' motor only and press enter.
+```
+
+You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
+
+Repeat the operation for each motor as instructed.
+
+> [!TIP]
+> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
+
+When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
+
+#### Leader
+
+Do the same steps for the leader arm but modify the command or script accordingly.
+
+
+
+
+```bash
+lerobot-setup-motors \
+ --teleop.type=koch_leader \
+ --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
+
+config = KochLeaderConfig(
+ port="/dev/tty.usbmodem575E0031751",
+ id="my_awesome_leader_arm",
+)
+leader = KochLeader(config)
+leader.setup_motors()
+```
+
+
+
+
+
+## Calibrate
+
+Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
+The calibration process is very important because it allows a neural network trained on one robot to work on another.
+
+#### Follower
+
+Run the following command or API example to calibrate the follower arm:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --robot.type=koch_follower \
+ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower
+
+config = KochFollowerConfig(
+ port="/dev/tty.usbmodem585A0076891",
+ id="my_awesome_follower_arm",
+)
+
+follower = KochFollower(config)
+follower.connect(calibrate=False)
+follower.calibrate()
+follower.disconnect()
+```
+
+
+
+
+
+We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
+
+#### Leader
+
+Do the same steps to calibrate the leader arm, run the following command or API example:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --teleop.type=koch_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader
+
+config = KochLeaderConfig(
+ port="/dev/tty.usbmodem575E0031751",
+ id="my_awesome_leader_arm",
+)
+
+leader = KochLeader(config)
+leader.connect(calibrate=False)
+leader.calibrate()
+leader.disconnect()
+```
+
+
+
+
+
+Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
+
+> [!TIP]
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx
deleted file mode 120000
index afc43077e..000000000
--- a/docs/source/lekiwi.mdx
+++ /dev/null
@@ -1 +0,0 @@
-../../src/lerobot/robots/lekiwi/lekiwi.mdx
\ No newline at end of file
diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx
new file mode 100644
index 000000000..875394d71
--- /dev/null
+++ b/docs/source/lekiwi.mdx
@@ -0,0 +1,337 @@
+# LeKiwi
+
+In the steps below, we explain how to assemble the LeKiwi mobile robot.
+
+## Source the parts
+
+Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts.
+And advise if it's your first time printing or if you don't own a 3D printer.
+
+### Wired version
+
+If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating.
+
+## Install software on Pi
+
+Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
+
+### Install OS
+
+For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
+
+### Setup SSH
+
+After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
+
+### Install LeRobot on Pi 🤗
+
+On your Raspberry Pi install LeRobot using our [Installation Guide](./installation)
+
+In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your Pi:
+
+```bash
+pip install -e ".[lekiwi]"
+```
+
+## Install LeRobot locally
+
+If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi.
+
+Follow our [Installation Guide](./installation)
+
+In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your laptop/pc:
+
+```bash
+pip install -e ".[lekiwi]"
+```
+
+Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:.
+Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
+
+# Step-by-Step Assembly Instructions
+
+First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages:
+
+- [Assemble SO101](./so101#step-by-step-assembly-instructions)
+- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md)
+
+### Find the USB ports associated with motor board
+
+To find the port for each bus servo adapter, run this script:
+
+```bash
+lerobot-find-port
+```
+
+
+
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/tty.usbmodem575E0032081']
+Remove the USB cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board.
+
+
+
+
+On Linux, you might need to give access to the USB ports by running:
+
+```bash
+sudo chmod 666 /dev/ttyACM0
+sudo chmod 666 /dev/ttyACM1
+```
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/ttyACM0']
+Remove the usb cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/ttyACM0
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/ttyACM0` corresponding to your board.
+
+
+
+
+### Configure motors
+
+The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9.
+
+You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
+
+```bash
+lerobot-setup-motors \
+ --robot.type=lekiwi \
+ --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
+```
+
+
+
+### Troubleshoot communication
+
+If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
+
+#### 1. Verify IP Address Configuration
+
+Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line):
+
+```bash
+hostname -I
+```
+
+#### 2. Check if Pi is reachable from laptop/pc
+
+Try pinging the Raspberry Pi from your laptop:
+
+```bach
+ping
+```
+
+If the ping fails:
+
+- Ensure the Pi is powered on and connected to the same network.
+- Check if SSH is enabled on the Pi.
+
+#### 3. Try SSH connection
+
+If you can't SSH into the Pi, it might not be properly connected. Use:
+
+```bash
+ssh @
+```
+
+If you get a connection error:
+
+- Ensure SSH is enabled on the Pi by running:
+ ```bash
+ sudo raspi-config
+ ```
+ Then navigate to: **Interfacing Options -> SSH** and enable it.
+
+### Calibration
+
+Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
+The calibration process is very important because it allows a neural network trained on one robot to work on another.
+
+### Calibrate follower arm (on mobile base)
+
+Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
+
+```bash
+lerobot-calibrate \
+ --robot.type=lekiwi \
+ --robot.id=my_awesome_kiwi # <- Give the robot a unique name
+```
+
+We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
+
+### Wired version
+
+If you have the **wired** LeKiwi version, please run all commands on your laptop.
+
+### Calibrate leader arm
+
+Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --teleop.type=so100_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
+
+config = SO100LeaderConfig(
+ port="/dev/tty.usbmodem58760431551",
+ id="my_awesome_leader_arm",
+)
+
+leader = SO100Leader(config)
+leader.connect(calibrate=False)
+leader.calibrate()
+leader.disconnect()
+```
+
+
+
+
+
+## Teleoperate LeKiwi
+
+> [!TIP]
+> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
+
+To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command:
+
+```bash
+python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi
+```
+
+Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`.
+
+```bash
+python examples/lekiwi/teleoperate.py
+```
+
+You should see on your laptop something like this: `[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
+
+| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
+| ---------- | ------------------ | ---------------------- |
+| Fast | 0.4 | 90 |
+| Medium | 0.25 | 60 |
+| Slow | 0.1 | 30 |
+
+| Key | Action |
+| --- | -------------- |
+| W | Move forward |
+| A | Move left |
+| S | Move backward |
+| D | Move right |
+| Z | Turn left |
+| X | Turn right |
+| R | Increase speed |
+| F | Decrease speed |
+
+> [!TIP]
+> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiClientConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/lekiwi/config_lekiwi.py).
+
+### Wired version
+
+If you have the **wired** LeKiwi version, please run all commands on your laptop.
+
+## Record a dataset
+
+Once you're familiar with teleoperation, you can record your first dataset.
+
+We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
+
+Add your token to the CLI by running this command:
+
+```bash
+huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
+```
+
+Then store your Hugging Face repository name in a variable:
+
+```bash
+HF_USER=$(huggingface-cli whoami | head -n 1)
+echo $HF_USER
+```
+
+Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`.
+
+```bash
+python examples/lekiwi/record.py
+```
+
+#### Dataset upload
+
+Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
+
+```bash
+echo https://huggingface.co/datasets/${HF_USER}/so101_test
+```
+
+Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
+
+You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
+
+#### Tips for gathering data
+
+Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
+
+In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
+
+Avoid adding too much variation too quickly, as it may hinder your results.
+
+If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset.
+
+#### Troubleshooting:
+
+- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
+
+## Replay an episode
+
+To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index.
+
+```bash
+python examples/lekiwi/replay.py
+```
+
+Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./il_robots)
+
+## Evaluate your policy
+
+To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model..
+
+```bash
+python examples/lekiwi/evaluate.py
+```
+
+> [!TIP]
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx
new file mode 100644
index 000000000..cf1942fdc
--- /dev/null
+++ b/docs/source/lerobot-dataset-v3.mdx
@@ -0,0 +1,281 @@
+# LeRobotDataset v3.0
+
+`LeRobotDataset v3.0` is a standardized format for robot learning data. It provides unified access to multi-modal time-series data, sensorimotor signals and multi‑camera video, as well as rich metadata for indexing, search, and visualization on the Hugging Face Hub.
+
+This docs will guide you to:
+
+- Understand the v3.0 design and directory layout
+- Record a dataset and push it to the Hub
+- Load datasets for training with `LeRobotDataset`
+- Stream datasets without downloading using `StreamingLeRobotDataset`
+- Apply image transforms for data augmentation during training
+- Migrate existing `v2.1` datasets to `v3.0`
+
+## What’s new in `v3`
+
+- **File-based storage**: Many episodes per Parquet/MP4 file (v2 used one file per episode).
+- **Relational metadata**: Episode boundaries and lookups are resolved through metadata, not filenames.
+- **Hub-native streaming**: Consume datasets directly from the Hub with `StreamingLeRobotDataset`.
+- **Lower file-system pressure**: Fewer, larger files ⇒ faster initialization and fewer issues at scale.
+- **Unified organization**: Clean directory layout with consistent path templates across data and videos.
+
+## Installation
+
+`LeRobotDataset v3.0` will be included in `lerobot >= 0.4.0`.
+
+Until that stable release, you can use the main branch by following the [build from source instructions](./installation#from-source).
+
+## Record a dataset
+
+Run the command below to record a dataset with the SO-101 and push to the Hub:
+
+```bash
+lerobot-record \
+ --robot.type=so101_follower \
+ --robot.port=/dev/tty.usbmodem585A0076841 \
+ --robot.id=my_awesome_follower_arm \
+ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
+ --teleop.type=so101_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \
+ --teleop.id=my_awesome_leader_arm \
+ --display_data=true \
+ --dataset.repo_id=${HF_USER}/record-test \
+ --dataset.num_episodes=5 \
+ --dataset.single_task="Grab the black cube"
+```
+
+See the [recording guide](./il_robots#record-a-dataset) for more details.
+
+## Format design
+
+A core v3 principle is **decoupling storage from the user API**: data is stored efficiently (few large files), while the public API exposes intuitive episode-level access.
+
+`v3` has three pillars:
+
+1. **Tabular data**: Low‑dimensional, high‑frequency signals (states, actions, timestamps) stored in **Apache Parquet**. Access is memory‑mapped or streamed via the `datasets` stack.
+2. **Visual data**: Camera frames concatenated and encoded into **MP4**. Frames from the same episode are grouped; videos are sharded per camera for practical sizes.
+3. **Metadata**: JSON/Parquet records describing schema (feature names, dtypes, shapes), frame rates, normalization stats, and **episode segmentation** (start/end offsets into shared Parquet/MP4 files).
+
+> To scale to millions of episodes, tabular rows and video frames from multiple episodes are **concatenated** into larger files. Episode‑specific views are reconstructed **via metadata**, not file boundaries.
+
+
+
+
+
+ From episode‑based to file‑based datasets
+
+
+
+
+### Directory layout (simplified)
+
+- **`meta/info.json`**: canonical schema (features, shapes/dtypes), FPS, codebase version, and **path templates** to locate data/video shards.
+- **`meta/stats.json`**: global feature statistics (mean/std/min/max) used for normalization; exposed as `dataset.meta.stats`.
+- **`meta/tasks.jsonl`**: natural‑language task descriptions mapped to integer IDs for task‑conditioned policies.
+- **`meta/episodes/`**: per‑episode records (lengths, tasks, offsets) stored as **chunked Parquet** for scalability.
+- **`data/`**: frame‑by‑frame **Parquet** shards; each file typically contains **many episodes**.
+- **`videos/`**: **MP4** shards per camera; each file typically contains **many episodes**.
+
+## Load a dataset for training
+
+`LeRobotDataset` returns Python dictionaries of PyTorch tensors and integrates with `torch.utils.data.DataLoader`. Here is a code example showing its use:
+
+```python
+import torch
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+repo_id = "yaak-ai/L2D-v3"
+
+# 1) Load from the Hub (cached locally)
+dataset = LeRobotDataset(repo_id)
+
+# 2) Random access by index
+sample = dataset[100]
+print(sample)
+# {
+# 'observation.state': tensor([...]),
+# 'action': tensor([...]),
+# 'observation.images.front_left': tensor([C, H, W]),
+# 'timestamp': tensor(1.234),
+# ...
+# }
+
+# 3) Temporal windows via delta_timestamps (seconds relative to t)
+delta_timestamps = {
+ "observation.images.front_left": [-0.2, -0.1, 0.0] # 0.2s and 0.1s before current frame
+}
+
+dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
+
+# Accessing an index now returns a stack for the specified key(s)
+sample = dataset[100]
+print(sample["observation.images.front_left"].shape) # [T, C, H, W], where T=3
+
+# 4) Wrap with a DataLoader for training
+batch_size = 16
+data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+for batch in data_loader:
+ observations = batch["observation.state"].to(device)
+ actions = batch["action"].to(device)
+ images = batch["observation.images.front_left"].to(device)
+ # model.forward(batch)
+```
+
+## Stream a dataset (no downloads)
+
+Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This allows to stream large datasets without the need to downloading them onto disk or loading them onto memory, and is a key feature of the new dataset format.
+
+```python
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+
+repo_id = "yaak-ai/L2D-v3"
+dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub
+```
+
+
+
+
+
+ Stream directly from the Hub for on‑the‑fly training.
+
+
+
+
+## Image transforms
+
+Image transforms are data augmentations applied to camera frames during training to improve model robustness and generalization. LeRobot supports various transforms including brightness, contrast, saturation, hue, and sharpness adjustments.
+
+### Using transforms during dataset creation/recording
+
+Currently, transforms are applied during **training time only**, not during recording. When you create or record a dataset, the raw images are stored without transforms. This allows you to experiment with different augmentations later without re-recording data.
+
+### Adding transforms to existing datasets (API)
+
+Use the `image_transforms` parameter when loading a dataset for training:
+
+```python
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig
+
+# Option 1: Use default transform configuration (disabled by default)
+transforms_config = ImageTransformsConfig(
+ enable=True, # Enable transforms
+ max_num_transforms=3, # Apply up to 3 transforms per frame
+ random_order=False, # Apply in standard order
+)
+transforms = ImageTransforms(transforms_config)
+
+dataset = LeRobotDataset(
+ repo_id="your-username/your-dataset",
+ image_transforms=transforms
+)
+
+# Option 2: Create custom transform configuration
+custom_transforms_config = ImageTransformsConfig(
+ enable=True,
+ max_num_transforms=2,
+ random_order=True,
+ tfs={
+ "brightness": ImageTransformConfig(
+ weight=1.0,
+ type="ColorJitter",
+ kwargs={"brightness": (0.7, 1.3)} # Adjust brightness range
+ ),
+ "contrast": ImageTransformConfig(
+ weight=2.0, # Higher weight = more likely to be selected
+ type="ColorJitter",
+ kwargs={"contrast": (0.8, 1.2)}
+ ),
+ "sharpness": ImageTransformConfig(
+ weight=0.5, # Lower weight = less likely to be selected
+ type="SharpnessJitter",
+ kwargs={"sharpness": (0.3, 2.0)}
+ ),
+ }
+)
+
+dataset = LeRobotDataset(
+ repo_id="your-username/your-dataset",
+ image_transforms=ImageTransforms(custom_transforms_config)
+)
+
+# Option 3: Use pure torchvision transforms
+from torchvision.transforms import v2
+
+torchvision_transforms = v2.Compose([
+ v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+ v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
+])
+
+dataset = LeRobotDataset(
+ repo_id="your-username/your-dataset",
+ image_transforms=torchvision_transforms
+)
+```
+
+### Available transform types
+
+LeRobot provides several transform types:
+
+- **`ColorJitter`**: Adjusts brightness, contrast, saturation, and hue
+- **`SharpnessJitter`**: Randomly adjusts image sharpness
+- **`Identity`**: No transformation (useful for testing)
+
+You can also use any `torchvision.transforms.v2` transform by passing it directly to the `image_transforms` parameter.
+
+### Configuration options
+
+- **`enable`**: Enable/disable transforms (default: `False`)
+- **`max_num_transforms`**: Maximum number of transforms applied per frame (default: `3`)
+- **`random_order`**: Apply transforms in random order vs. standard order (default: `False`)
+- **`weight`**: Sampling probability for each transform (higher = more likely, if sum of weights is not 1, they will be normalized)
+- **`kwargs`**: Transform-specific parameters (e.g., brightness range)
+
+### Visualizing transforms
+
+Use the visualization script to preview how transforms affect your data:
+
+```bash
+lerobot-imgtransform-viz \
+ --repo-id=your-username/your-dataset \
+ --output-dir=./transform_examples \
+ --n-examples=5
+```
+
+This saves example images showing the effect of each transform, helping you tune parameters.
+
+### Best practices
+
+- **Start conservative**: Begin with small ranges (e.g., brightness 0.9-1.1) and increase gradually
+- **Test first**: Use the visualization script to ensure transforms look reasonable
+- **Monitor training**: Strong augmentations can hurt performance if too aggressive
+- **Match your domain**: If your robot operates in varying lighting, use brightness/contrast transforms
+- **Combine wisely**: Using too many transforms simultaneously can make training unstable
+
+## Migrate `v2.1` → `v3.0`
+
+A converter aggregates per‑episode files into larger shards and writes episode offsets/metadata. Convert your dataset using the instructions below.
+
+```bash
+# Pre-release build with v3 support:
+pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
+
+# Convert an existing v2.1 dataset hosted on the Hub:
+python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=
+```
+
+**What it does**
+
+- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, …
+- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, …
+- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets.
diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx
new file mode 100644
index 000000000..3f2b92406
--- /dev/null
+++ b/docs/source/libero.mdx
@@ -0,0 +1,166 @@
+# LIBERO
+
+**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots won’t just be pretrained once in a factory, they’ll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and it’s a key step toward building robots that become truly personalized helpers.
+
+- 📄 [LIBERO paper](https://arxiv.org/abs/2306.03310)
+- 💻 [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO)
+
+To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other’s work.
+
+LIBERO includes **five task suites**:
+
+- **LIBERO-Spatial (`libero_spatial`)** – tasks that require reasoning about spatial relations.
+- **LIBERO-Object (`libero_object`)** – tasks centered on manipulating different objects.
+- **LIBERO-Goal (`libero_goal`)** – goal-conditioned tasks where the robot must adapt to changing targets.
+- **LIBERO-90 (`libero_90`)** – 90 short-horizon tasks from the LIBERO-100 collection.
+- **LIBERO-Long (`libero_10`)** – 10 long-horizon tasks from the LIBERO-100 collection.
+
+Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms.
+
+
+
+## Evaluating with LIBERO
+
+At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model.
+
+LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag.
+
+To Install LIBERO, after following LeRobot official instructions, just do:
+`pip install -e ".[libero]"`
+
+### Single-suite evaluation
+
+Evaluate a policy on one LIBERO suite:
+
+```bash
+lerobot-eval \
+ --policy.path="your-policy-id" \
+ --env.type=libero \
+ --env.task=libero_object \
+ --eval.batch_size=2 \
+ --eval.n_episodes=3
+```
+
+- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.).
+- `--eval.batch_size` controls how many environments run in parallel.
+- `--eval.n_episodes` sets how many episodes to run in total.
+
+---
+
+### Multi-suite evaluation
+
+Benchmark a policy across multiple suites at once:
+
+```bash
+lerobot-eval \
+ --policy.path="your-policy-id" \
+ --env.type=libero \
+ --env.task=libero_object,libero_spatial \
+ --eval.batch_size=1 \
+ --eval.n_episodes=2
+```
+
+- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
+
+### Policy inputs and outputs
+
+When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
+
+- **Observations**
+ - `observation.state` – proprioceptive features (agent state).
+ - `observation.images.image` – main camera view (`agentview_image`).
+ - `observation.images.image2` – wrist camera view (`robot0_eye_in_hand_image`).
+
+ ⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation.
+ If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer.
+ This will be fixed with the upcoming Pipeline PR.
+
+- **Actions**
+ - Continuous control values in a `Box(-1, 1, shape=(7,))` space.
+
+We also provide a notebook for quick testing:
+Training with LIBERO
+
+## Training with LIBERO
+
+When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention.
+
+The environment expects:
+
+- `observation.state` → 8-dim agent state
+- `observation.images.image` → main camera (`agentview_image`)
+- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`)
+
+⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code.
+To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation:
+👉 [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero)
+
+For reference, here is the **original dataset** published by Physical Intelligence:
+👉 [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero)
+
+---
+
+### Example training command
+
+```bash
+lerobot-train \
+ --policy.type=smolvla \
+ --policy.repo_id=${HF_USER}/libero-test \
+ --policy.load_vlm_weights=true \
+ --dataset.repo_id=HuggingFaceVLA/libero \
+ --env.type=libero \
+ --env.task=libero_10 \
+ --output_dir=./outputs/ \
+ --steps=100000 \
+ --batch_size=4 \
+ --eval.batch_size=1 \
+ --eval.n_episodes=1 \
+ --eval_freq=1000 \
+```
+
+---
+
+### Note on rendering
+
+LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
+
+- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
+
+## Reproducing π₀.₅ results
+
+We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
+
+The finetuned model can be found here:
+
+- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
+
+We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
+
+```bash
+python src/lerobot/scripts/eval.py \
+ --output_dir=/logs/ \
+ --env.type=libero \
+ --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
+ --eval.batch_size=1 \
+ --eval.n_episodes=10 \
+ --policy.path=pi05_libero_finetuned \
+ --policy.n_action_steps=10 \
+ --output_dir=./eval_logs/ \
+ --env.max_parallel_tasks=1
+```
+
+**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
+
+### Results
+
+We obtain the following results on the LIBERO benchmark:
+
+| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
+| -------- | -------------- | ------------- | ----------- | --------- | -------- |
+| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
+
+These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
+
+| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
+| -------- | -------------- | ------------- | ----------- | --------- | --------- |
+| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
diff --git a/docs/source/notebooks.mdx b/docs/source/notebooks.mdx
index 729b31a99..6a9c3b103 100644
--- a/docs/source/notebooks.mdx
+++ b/docs/source/notebooks.mdx
@@ -10,8 +10,8 @@ This repository contains example notebooks for using LeRobot. These notebooks de
We provide a ready-to-run Google Colab notebook to help you train ACT policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases.
-| Notebook | Colab |
-|:---------|:------|
+| Notebook | Colab |
+| :------------------------------------------------------------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [Train ACT with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) |
Expected training time for 100k steps: ~1.5 hours on an NVIDIA A100 GPU with batch size of `64`.
diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx
new file mode 100644
index 000000000..22159193c
--- /dev/null
+++ b/docs/source/phone_teleop.mdx
@@ -0,0 +1,191 @@
+# Phone
+
+Use your phone (iOS or Android) to control your robot.
+
+**In this guide you'll learn:**
+
+- How to connect an iOS/Android phone
+- How phone pose is mapped to robot end‑effector (EE) targets
+- How to tweak safety limits, gripper control, and IK settings
+
+To use phone to control your robot, install the relevant dependencies with:
+
+```bash
+pip install lerobot[phone]
+```
+
+## Get started
+
+### Supported platforms
+
+- iOS: Uses the HEBI Mobile I/O app (ARKit pose + buttons). Download the app first, open it and the examples will discover it on your network and stream the phone pose and inputs.
+- Android: Uses the `teleop` package (WebXR). When you start the Python process, it prints a local URL. Open the link on your phone, tap Start, then use Move to stream pose.
+
+Links:
+
+- Android WebXR library: [`teleop` on PyPI](https://pypi.org/project/teleop/)
+- iOS app: [HEBI Mobile I/O](https://docs.hebi.us/tools.html#mobile-io)
+
+### Phone orientation and controls
+
+- Orientation: hold the phone with the screen facing up and the top edge pointing in the same direction as the robot gripper. This ensures calibration aligns the phone’s frame with the robot frame so motion feels natural, see the image below for reference.
+- Enable/disable:
+ - iOS: Hold `B1` to enable teleoperation, release to stop. The first press captures a reference pose.
+ - Android: Press and hold the `Move` button, release to stop. The first press captures a reference pose.
+- Gripper control:
+ - iOS: Analog input `A3` controls the gripper as velocity input.
+ - Android: Buttons `A` and `B` act like increment/decrement (A opens, B closes). You can tune velocity in the `GripperVelocityToJoint` step.
+
+
+
+### Step 1: Choose the platform
+
+Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`. The API is identical across platforms, only the input source differs. All examples are under `examples/` and have `phone_so100_*.py` variants.
+
+Teleoperation example:
+
+```36:43:examples/phone_so100_teleop.py
+from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
+
+teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
+teleop_device = Phone(teleop_config)
+```
+
+### Step 2: Connect and calibrate
+
+When `Phone(teleop_config)` is created and `connect()` is called, calibration is prompted automatically. Hold the phone in the orientation described above, then:
+
+- iOS: press and hold `B1` to capture the reference pose.
+- Android: press `Move` button on the WebXR page to capture the reference pose.
+
+Why calibrate? We capture the current pose so subsequent poses are expressed in a robot aligned frame. When you again press the button to enable control, the position is recaptured to avoid drift when your phone is repositioned while it was disabled.
+
+### Step 3: Run an example
+
+Run on of the examples scripts to teleoperate, record a dataset, replay a dataset or evaluate a policy.
+
+All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
+
+Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
+
+- Run this example to teleoperate:
+
+ ```bash
+ python examples/phone_to_so100/teleoperate.py
+ ```
+
+After running the example:
+
+- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move.
+- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper.
+
+Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide.
+
+- Run this example to record a dataset, which saves absolute end effector observations and actions:
+
+ ```bash
+ python examples/phone_to_so100/record.py
+ ```
+
+- Run this example to replay recorded episodes:
+
+ ```bash
+ python examples/phone_to_so100/replay.py
+ ```
+
+- Run this example to evaluate a pretrained policy:
+
+ ```bash
+ python examples/phone_to_so100/evaluate.py
+ ```
+
+### Important pipeline steps and options
+
+- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame.
+
+ ```examples/phone_to_so100/teleoperate.py
+ kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+ )
+
+ ```
+
+- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs.
+
+ ```src/lerobot/teleoperators/phone/phone_processor.py
+ action["enabled"] = enabled
+ action["target_x"] = -pos[1] if enabled else 0.0
+ action["target_y"] = pos[0] if enabled else 0.0
+ action["target_z"] = pos[2] if enabled else 0.0
+ action["target_wx"] = rotvec[1] if enabled else 0.0
+ action["target_wy"] = rotvec[0] if enabled else 0.0
+ action["target_wz"] = -rotvec[2] if enabled else 0.0
+ action["gripper_vel"] = gripper_vel # Still send gripper action when disabled
+ ```
+
+- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed.
+
+ ```examples/phone_to_so100/teleoperate.py
+ EEReferenceAndDelta(
+ kinematics=kinematics_solver,
+ end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
+ motor_names=list(robot.bus.motors.keys()),
+ use_latched_reference=True,
+ ),
+ ```
+
+- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
+
+ ```examples/phone_to_so100/teleoperate.py
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
+ max_ee_step_m=0.10,
+ )
+ ```
+
+- The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied.
+
+ ```examples/phone_to_so100/teleoperate.py
+ GripperVelocityToJoint(speed_factor=20.0)
+ ```
+
+#### Different IK initial guesses
+
+We use different IK initial guesses in the kinematic steps. As initial guess either the current measured joints or the previous IK solution is used.
+
+- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame.
+
+ ```examples/phone_to_so100/record.py
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=True, # closed loop
+ )
+ ```
+
+- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback.
+
+ ```examples/phone_to_so100/replay.py
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=False, # open loop
+ )
+ ```
+
+### Pipeline steps explained
+
+- MapPhoneActionToRobotAction: converts calibrated phone pose and inputs into target deltas and a gripper command. Motion is gated by an enable signal (B1 on iOS, Move on Android).
+- EEReferenceAndDelta: latches a reference EE pose on enable and combines it with target deltas to produce an absolute desired EE pose each frame. When disabled, it keeps sending the last commanded pose.
+- EEBoundsAndSafety: clamps the EE pose to a workspace and rate‑limits jumps for safety. Also declares `action.ee.*` features.
+- InverseKinematicsEEToJoints: turns an EE pose into joint positions with IK. `initial_guess_current_joints=True` is recommended for closed‑loop control; set `False` for open‑loop replay for stability.
+- GripperVelocityToJoint: integrates a velocity‑like gripper input into an absolute gripper position using the current measured state.
+- ForwardKinematicsJointsToEE: computes `observation.state.ee.*` from observed joints for logging and training on EE state.
+
+### Troubleshooting
+
+- iOS not discovered: ensure HEBI Mobile I/O is open and your laptop/phone are on the same network.
+- Android URL not reachable: check local you used `https` instead of `http`, use the exact IP printed by the script and allow your browser to enter and ignore the certificate issue.
+- Motion feels inverted: adjust the sign flips in `MapPhoneActionToRobotAction` or swap axes to match your setup.
diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx
new file mode 100644
index 000000000..d36fe0ce4
--- /dev/null
+++ b/docs/source/pi0.mdx
@@ -0,0 +1,79 @@
+# π₀ (Pi0)
+
+π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
+
+## Model Overview
+
+π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
+
+### The Vision for Physical Intelligence
+
+As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
+
+### Architecture and Approach
+
+π₀ combines several key innovations:
+
+- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
+- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
+- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
+- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
+
+## Installation Requirements
+
+1. Install LeRobot by following our [Installation Guide](./installation).
+2. Install Pi0 dependencies by running:
+
+ ```bash
+ pip install -e ".[pi]"
+ ```
+
+## Training Data and Capabilities
+
+π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
+
+1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
+2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
+3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
+
+## Usage
+
+To use π₀ in LeRobot, specify the policy type as:
+
+```python
+policy.type=pi0
+```
+
+## Training
+
+For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
+
+```bash
+python src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your_dataset \
+ --policy.type=pi0 \
+ --output_dir=./outputs/pi0_training \
+ --job_name=pi0_training \
+ --policy.pretrained_path=lerobot/pi0_base \
+ --policy.repo_id=your_repo_id \
+ --policy.compile_model=true \
+ --policy.gradient_checkpointing=true \
+ --policy.dtype=bfloat16 \
+ --steps=3000 \
+ --policy.device=cuda \
+ --batch_size=32
+```
+
+### Key Training Parameters
+
+- **`--policy.compile_model=true`**: Enables model compilation for faster training
+- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
+- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
+- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
+- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
+ - [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
+ - [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
+
+## License
+
+This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx
new file mode 100644
index 000000000..b6267fc5e
--- /dev/null
+++ b/docs/source/pi05.mdx
@@ -0,0 +1,107 @@
+# π₀.₅ (Pi05) Policy
+
+π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
+
+## Model Overview
+
+π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
+
+### The Generalization Challenge
+
+As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
+
+- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
+- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
+- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
+
+### Co-Training on Heterogeneous Data
+
+The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
+
+1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
+2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
+3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
+4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
+5. **Multi-Environment Data**: Static robots deployed across many different homes
+6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
+
+This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
+
+## Installation Requirements
+
+1. Install LeRobot by following our [Installation Guide](./installation).
+2. Install Pi0.5 dependencies by running:
+
+ ```bash
+ pip install -e ".[pi]"
+ ```
+
+## Usage
+
+To use π₀.₅ in your LeRobot configuration, specify the policy type as:
+
+```python
+policy.type=pi05
+```
+
+## Training
+
+### Training Command Example
+
+Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
+
+```bash
+python src/lerobot/scripts/lerobot_train.py\
+ --dataset.repo_id=your_dataset \
+ --policy.type=pi05 \
+ --output_dir=./outputs/pi05_training \
+ --job_name=pi05_training \
+ --policy.repo_id=your_repo_id \
+ --policy.pretrained_path=lerobot/pi05_base \
+ --policy.compile_model=true \
+ --policy.gradient_checkpointing=true \
+ --wandb.enable=true \
+ --policy.dtype=bfloat16 \
+ --steps=3000 \
+ --policy.device=cuda \
+ --batch_size=32
+```
+
+### Key Training Parameters
+
+- **`--policy.compile_model=true`**: Enables model compilation for faster training
+- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
+- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
+- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
+- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
+ - [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
+ - [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
+
+If your dataset is not converted with `quantiles`, you can convert it with the following command:
+
+```bash
+python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
+ --repo-id=your_dataset \
+```
+
+Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
+
+## Performance Results
+
+### Libero Benchmark Results
+
+π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
+
+| Benchmark | LeRobot Implementation | OpenPI Reference |
+| ------------------ | ---------------------- | ---------------- |
+| **Libero Spatial** | 97.0% | 98.8% |
+| **Libero Object** | 99.0% | 98.2% |
+| **Libero Goal** | 98.0% | 98.0% |
+| **Libero 10** | 96.0% | 92.4% |
+| **Average** | 97.5% | 96.85% |
+
+These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
+
+## License
+
+This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
diff --git a/docs/source/policy_act_README.md b/docs/source/policy_act_README.md
new file mode 100644
index 000000000..371a9136f
--- /dev/null
+++ b/docs/source/policy_act_README.md
@@ -0,0 +1,14 @@
+## Paper
+
+https://tonyzhaozh.github.io/aloha
+
+## Citation
+
+```bibtex
+@article{zhao2023learning,
+ title={Learning fine-grained bimanual manipulation with low-cost hardware},
+ author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
+ journal={arXiv preprint arXiv:2304.13705},
+ year={2023}
+}
+```
diff --git a/docs/source/policy_diffusion_README.md b/docs/source/policy_diffusion_README.md
new file mode 100644
index 000000000..9ec934add
--- /dev/null
+++ b/docs/source/policy_diffusion_README.md
@@ -0,0 +1,14 @@
+## Paper
+
+https://diffusion-policy.cs.columbia.edu
+
+## Citation
+
+```bibtex
+@article{chi2024diffusionpolicy,
+ author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
+ title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
+ journal = {The International Journal of Robotics Research},
+ year = {2024},
+}
+```
diff --git a/docs/source/policy_smolvla_README.md b/docs/source/policy_smolvla_README.md
new file mode 100644
index 000000000..ee567ee83
--- /dev/null
+++ b/docs/source/policy_smolvla_README.md
@@ -0,0 +1,14 @@
+## Paper
+
+https://arxiv.org/abs/2506.01844
+
+## Citation
+
+```bibtex
+@article{shukor2025smolvla,
+ title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics},
+ author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi},
+ journal={arXiv preprint arXiv:2506.01844},
+ year={2025}
+}
+```
diff --git a/docs/source/policy_tdmpc_README.md b/docs/source/policy_tdmpc_README.md
new file mode 100644
index 000000000..804f166c8
--- /dev/null
+++ b/docs/source/policy_tdmpc_README.md
@@ -0,0 +1,14 @@
+## Paper
+
+https://www.nicklashansen.com/td-mpc/
+
+## Citation
+
+```bibtex
+@inproceedings{Hansen2022tdmpc,
+ title={Temporal Difference Learning for Model Predictive Control},
+ author={Nicklas Hansen and Xiaolong Wang and Hao Su},
+ booktitle={ICML},
+ year={2022}
+}
+```
diff --git a/docs/source/policy_vqbet_README.md b/docs/source/policy_vqbet_README.md
new file mode 100644
index 000000000..02f95b7c2
--- /dev/null
+++ b/docs/source/policy_vqbet_README.md
@@ -0,0 +1,14 @@
+## Paper
+
+https://sjlee.cc/vq-bet/
+
+## Citation
+
+```bibtex
+@article{lee2024behavior,
+ title={Behavior generation with latent actions},
+ author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
+ journal={arXiv preprint arXiv:2403.03181},
+ year={2024}
+}
+```
diff --git a/docs/source/porting_datasets_v3.mdx b/docs/source/porting_datasets_v3.mdx
new file mode 100644
index 000000000..46793265e
--- /dev/null
+++ b/docs/source/porting_datasets_v3.mdx
@@ -0,0 +1,321 @@
+# Porting Large Datasets to LeRobot Dataset v3.0
+
+This tutorial explains how to port large-scale robotic datasets to the LeRobot Dataset v3.0 format. We'll use the **DROID 1.0.1** dataset as our primary example, which demonstrates handling multi-terabyte datasets with thousands of shards across SLURM clusters.
+
+## File Organization: v2.1 vs v3.0
+
+Dataset v3.0 fundamentally changes how data is organized and stored:
+
+**v2.1 Structure (Episode-based)**:
+
+```
+dataset/
+├── data/chunk-000/episode_000000.parquet
+├── data/chunk-000/episode_000001.parquet
+├── videos/chunk-000/camera/episode_000000.mp4
+└── meta/episodes.jsonl
+```
+
+**v3.0 Structure (File-based)**:
+
+```
+dataset/
+├── data/chunk-000/file-000.parquet # Multiple episodes per file
+├── videos/camera/chunk-000/file-000.mp4 # Consolidated video chunks
+└── meta/episodes/chunk-000/file-000.parquet # Structured metadata
+```
+
+This transition from individual episode files to file-based chunks dramatically improves performance and reduces storage overhead.
+
+## What's New in Dataset v3.0
+
+Dataset v3.0 introduces significant improvements for handling large datasets:
+
+### 🏗️ **Enhanced File Organization**
+
+- **File-based structure**: Episodes are now grouped into chunked files rather than individual episode files
+- **Configurable file sizes**: for data and video files
+- **Improved storage efficiency**: Better compression and reduced overhead
+
+### 📊 **Modern Metadata Management**
+
+- **Parquet-based metadata**: Replaced JSON Lines with efficient parquet format
+- **Structured episode access**: Direct pandas DataFrame access via `dataset.meta.episodes`
+- **Per-episode statistics**: Enhanced statistics tracking at episode level
+
+### 🚀 **Performance Enhancements**
+
+- **Memory-mapped access**: Improved RAM usage through PyArrow memory mapping
+- **Faster loading**: Significantly reduced dataset initialization time
+- **Better scalability**: Designed for datasets with millions of episodes
+
+## Prerequisites
+
+Before porting large datasets, ensure you have:
+
+- **LeRobot installed** with v3.0 support. Follow our [Installation Guide](./installation).
+- **Sufficient storage**: Raw datasets can be very large (e.g., DROID requires 2TB)
+- **Cluster access** (recommended for large datasets): SLURM or similar job scheduler
+- **Dataset-specific dependencies**: For DROID, you'll need TensorFlow Dataset utilities
+
+## Understanding the DROID Dataset
+
+[DROID 1.0.1](https://droid-dataset.github.io/droid/the-droid-dataset) is an excellent example of a large-scale robotic dataset:
+
+- **Size**: 1.7TB (RLDS format), 8.7TB (raw data)
+- **Structure**: 2048 pre-defined TensorFlow dataset shards
+- **Content**: 76,000+ robot manipulation trajectories from Franka Emika Panda robots
+- **Scope**: Real-world manipulation tasks across multiple environments and objects
+- **Format**: Originally in TensorFlow Records/RLDS format, requiring conversion to LeRobot format
+- **Hosting**: Google Cloud Storage with public access via `gsutil`
+
+The dataset contains diverse manipulation demonstrations with:
+
+- Multiple camera views (wrist camera, exterior cameras)
+- Natural language task descriptions
+- Robot proprioceptive state and actions
+- Success/failure annotations
+
+### DROID Features Schema
+
+```python
+DROID_FEATURES = {
+ # Episode markers
+ "is_first": {"dtype": "bool", "shape": (1,)},
+ "is_last": {"dtype": "bool", "shape": (1,)},
+ "is_terminal": {"dtype": "bool", "shape": (1,)},
+
+ # Language instructions
+ "language_instruction": {"dtype": "string", "shape": (1,)},
+ "language_instruction_2": {"dtype": "string", "shape": (1,)},
+ "language_instruction_3": {"dtype": "string", "shape": (1,)},
+
+ # Robot state
+ "observation.state.gripper_position": {"dtype": "float32", "shape": (1,)},
+ "observation.state.cartesian_position": {"dtype": "float32", "shape": (6,)},
+ "observation.state.joint_position": {"dtype": "float32", "shape": (7,)},
+
+ # Camera observations
+ "observation.images.wrist_left": {"dtype": "image"},
+ "observation.images.exterior_1_left": {"dtype": "image"},
+ "observation.images.exterior_2_left": {"dtype": "image"},
+
+ # Actions
+ "action.gripper_position": {"dtype": "float32", "shape": (1,)},
+ "action.cartesian_position": {"dtype": "float32", "shape": (6,)},
+ "action.joint_position": {"dtype": "float32", "shape": (7,)},
+
+ # Standard LeRobot format
+ "observation.state": {"dtype": "float32", "shape": (8,)}, # joints + gripper
+ "action": {"dtype": "float32", "shape": (8,)}, # joints + gripper
+}
+```
+
+## Approach 1: Single Computer Porting
+
+### Step 1: Install Dependencies
+
+For DROID specifically:
+
+```bash
+pip install tensorflow
+pip install tensorflow_datasets
+```
+
+For other datasets, install the appropriate readers for your source format.
+
+### Step 2: Download Raw Data
+
+Download DROID from Google Cloud Storage using `gsutil`:
+
+```bash
+# Install Google Cloud SDK if not already installed
+# https://cloud.google.com/sdk/docs/install
+
+# Download the full RLDS dataset (1.7TB)
+gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 /your/data/
+
+# Or download just the 100-episode sample (2GB) for testing
+gsutil -m cp -r gs://gresearch/robotics/droid_100 /your/data/
+```
+
+> [!WARNING]
+> Large datasets require substantial time and storage:
+>
+> - **Full DROID (1.7TB)**: Several days to download depending on bandwidth
+> - **Processing time**: 7+ days for local porting of full dataset
+> - **Upload time**: 3+ days to push to Hugging Face Hub
+> - **Local storage**: ~400GB for processed LeRobot format
+
+### Step 3: Port the Dataset
+
+```bash
+python examples/port_datasets/port_droid.py \
+ --raw-dir /your/data/droid/1.0.1 \
+ --repo-id your_id/droid_1.0.1 \
+ --push-to-hub
+```
+
+### Development and Testing
+
+For development, you can port a single shard:
+
+```bash
+python examples/port_datasets/port_droid.py \
+ --raw-dir /your/data/droid/1.0.1 \
+ --repo-id your_id/droid_1.0.1_test \
+ --num-shards 2048 \
+ --shard-index 0
+```
+
+This approach works for smaller datasets or testing, but large datasets require cluster computing.
+
+## Approach 2: SLURM Cluster Porting (Recommended)
+
+For large datasets like DROID, parallel processing across multiple nodes dramatically reduces processing time.
+
+### Step 1: Install Cluster Dependencies
+
+```bash
+pip install datatrove # Hugging Face's distributed processing library
+```
+
+### Step 2: Configure Your SLURM Environment
+
+Find your partition information:
+
+```bash
+sinfo --format="%R" # List available partitions
+sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m" # Check resources
+```
+
+Choose a **CPU partition** - no GPU needed for dataset porting.
+
+### Step 3: Launch Parallel Porting Jobs
+
+```bash
+python examples/port_datasets/slurm_port_shards.py \
+ --raw-dir /your/data/droid/1.0.1 \
+ --repo-id your_id/droid_1.0.1 \
+ --logs-dir /your/logs \
+ --job-name port_droid \
+ --partition your_partition \
+ --workers 2048 \
+ --cpus-per-task 8 \
+ --mem-per-cpu 1950M
+```
+
+#### Parameter Guidelines
+
+- **`--workers`**: Number of parallel jobs (max 2048 for DROID's shard count)
+- **`--cpus-per-task`**: 8 CPUs recommended for frame encoding parallelization
+- **`--mem-per-cpu`**: ~16GB total RAM (8×1950M) for loading raw frames
+
+> [!TIP]
+> Start with fewer workers (e.g., 100) to test your cluster configuration before launching thousands of jobs.
+
+### Step 4: Monitor Progress
+
+Check running jobs:
+
+```bash
+squeue -u $USER
+```
+
+Monitor overall progress:
+
+```bash
+jobs_status /your/logs
+```
+
+Inspect individual job logs:
+
+```bash
+less /your/logs/port_droid/slurm_jobs/JOB_ID_WORKER_ID.out
+```
+
+Debug failed jobs:
+
+```bash
+failed_logs /your/logs/port_droid
+```
+
+### Step 5: Aggregate Shards
+
+Once all porting jobs complete:
+
+```bash
+python examples/port_datasets/slurm_aggregate_shards.py \
+ --repo-id your_id/droid_1.0.1 \
+ --logs-dir /your/logs \
+ --job-name aggr_droid \
+ --partition your_partition \
+ --workers 2048 \
+ --cpus-per-task 8 \
+ --mem-per-cpu 1950M
+```
+
+### Step 6: Upload to Hub
+
+```bash
+python examples/port_datasets/slurm_upload.py \
+ --repo-id your_id/droid_1.0.1 \
+ --logs-dir /your/logs \
+ --job-name upload_droid \
+ --partition your_partition \
+ --workers 50 \
+ --cpus-per-task 4 \
+ --mem-per-cpu 1950M
+```
+
+> [!NOTE]
+> Upload uses fewer workers (50) since it's network-bound rather than compute-bound.
+
+## Dataset v3.0 File Structure
+
+Your completed dataset will have this modern structure:
+
+```
+dataset/
+├── meta/
+│ ├── episodes/
+│ │ └── chunk-000/
+│ │ └── file-000.parquet # Episode metadata
+│ ├── tasks.parquet # Task definitions
+│ ├── stats.json # Aggregated statistics
+│ └── info.json # Dataset information
+├── data/
+│ └── chunk-000/
+│ └── file-000.parquet # Consolidated episode data
+└── videos/
+ └── camera_key/
+ └── chunk-000/
+ └── file-000.mp4 # Consolidated video files
+```
+
+This replaces the old episode-per-file structure with efficient, optimally-sized chunks.
+
+## Migrating from Dataset v2.1
+
+If you have existing datasets in v2.1 format, use the migration tool:
+
+```bash
+python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
+ --repo-id your_id/existing_dataset
+```
+
+This automatically:
+
+- Converts file structure to v3.0 format
+- Migrates metadata from JSON Lines to parquet
+- Aggregates statistics and creates per-episode stats
+- Updates version information
+
+## Performance Benefits
+
+Dataset v3.0 provides significant improvements for large datasets:
+
+- **Faster loading**: 3-5x reduction in initialization time
+- **Memory efficiency**: Better RAM usage through memory mapping
+- **Scalable processing**: Handles millions of episodes efficiently
+- **Storage optimization**: Reduced file count and improved compression
diff --git a/docs/source/processors_robots_teleop.mdx b/docs/source/processors_robots_teleop.mdx
new file mode 100644
index 000000000..3d8dcb409
--- /dev/null
+++ b/docs/source/processors_robots_teleop.mdx
@@ -0,0 +1,151 @@
+# Processors for Robots and Teleoperators
+
+This guide shows how to build and modify processing pipelines that connect teleoperators (e.g., phone) to robots and datasets. Pipelines standardize conversions between different action/observation spaces so you can swap teleops and robots without rewriting glue code.
+
+We use the Phone to SO‑100 follower examples for concreteness, but the same patterns apply to other robots.
+
+**What you'll learn**
+
+- Absolute vs. relative EE control: What each means, trade‑offs, and how to choose for your task.
+- Three-pipeline pattern: How to map teleop actions → dataset actions → robot commands, and robot observations → dataset observations.
+- Adapters (`to_transition` / `to_output`): How these convert raw dicts to `EnvTransition` and back to reduce boilerplate.
+- Dataset feature contracts: How steps declare features via `transform_features(...)`, and how to aggregate/merge them for recording.
+- Choosing a representation: When to store joints, absolute EE poses, or relative EE deltas—and how that affects training.
+- Pipeline customization guidance: How to swap robots/URDFs safely and tune bounds, step sizes, and options like IK initialization.
+
+### Absolute vs relative EE control
+
+The examples in this guide use absolute end effector (EE) poses because they are easy to reason about. In practice, relative EE deltas or joint position are often preferred as learning features.
+
+With processors, you choose the learning features you want to use for your policy. This could be joints positions/velocities, absolute EE, or relative EE positions. You can also choose to store other features, such as joint torques, motor currents, etc.
+
+## Three pipelines
+
+We often compose three pipelines. Depending on your setup, some can be empty if action and observation spaces already match.
+Each of these pipelines handle different conversions between different action and observation spaces. Below is a quick explanation of each pipeline.
+
+1. Pipeline 1: Teleop action space → dataset action space (phone pose → EE targets)
+2. Pipeline 2: Dataset action space → robot command space (EE targets → joints)
+3. Pipeline 3: Robot observation space → dataset observation space (joints → EE pose)
+
+Below is an example of the three pipelines that we use in the phone to SO-100 follower examples:
+
+```69:90:examples/phone_so100_record.py
+phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action
+ steps=[
+ MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
+ EEReferenceAndDelta(
+ kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()),
+ ),
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20,
+ ),
+ GripperVelocityToJoint(),
+ ],
+ to_transition=robot_action_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+robot_ee_to_joints_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # dataset action -> robot
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( # robot obs -> dataset obs
+ steps=[
+ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
+ ],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+)
+```
+
+## Why to_transition / to_output
+
+To convert from robot/teleoperator to pipeline and back, we use the `to_transition` and `to_output` pipeline adapters.
+They standardize conversions to reduce boilerplate code, and form the bridge between the robot and teleoperators raw dictionaries and the pipeline’s `EnvTransition` format.
+In the phone to SO-100 follower examples we use the following adapters:
+
+- `robot_action_to_transition`: transforms the teleop action dict to a pipeline transition.
+- `transition_to_robot_action`: transforms the pipeline transition to a robot action dict.
+- `observation_to_transition`: transforms the robot observation dict to a pipeline transition.
+- `transition_to_observation`: transforms the pipeline transition to a observation dict.
+
+Checkout [src/lerobot/processor/converters.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/processor/converters.py) for more details.
+
+## Dataset feature contracts
+
+Dataset features are determined by the keys saved in the dataset. Each step can declare what features it modifies in a contract called `transform_features(...)`. Once you build a processor, the processor can then aggregate all of these features with `aggregate_pipeline_dataset_features()` and merge multiple feature dicts with `combine_feature_dicts(...)`.
+
+Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples:
+
+```src/lerobot/robots/so100_follower/robot_kinematic_processor.py
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ # We only use the ee pose in the dataset, so we don't need the joint positions
+ for n in self.motor_names:
+ features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
+ # We specify the dataset features of this step that we want to be stored in the dataset
+ for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
+ features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature(
+ type=FeatureType.STATE, shape=(1,)
+ )
+ return features
+```
+
+Here we declare what PolicyFeatures we modify in this step, so we know what features we can expect when we run the processor. These features can then be aggregated and used to create the dataset features.
+
+Below is an example of how we aggregate and merge features in the phone to SO-100 record example:
+
+```121:145:examples/phone_so100_record.py
+features=combine_feature_dicts(
+ # Run the feature contract of the pipelines
+ # This tells you how the features would look like after the pipeline steps
+ aggregate_pipeline_dataset_features(
+ pipeline=phone_to_robot_ee_pose_processor,
+ initial_features=create_initial_features(action=phone.action_features), # <- Action features we can expect, these come from our teleop device (phone) and action processor
+ use_videos=True,
+ ),
+ aggregate_pipeline_dataset_features(
+ pipeline=robot_joints_to_ee_pose,
+ initial_features=create_initial_features(observation=robot.observation_features), # <- Observation features we can expect, these come from our robot and observation processor
+ use_videos=True,
+ patterns=["observation.state.ee"], # <- Here you could optionally filter the features we want to store in the dataset, with a specific pattern
+
+ ),
+ ),
+```
+
+How it works:
+
+- `aggregate_pipeline_dataset_features(...)`: applies `transform_features` across the pipeline and filters by patterns (images included when `use_videos=True`, and state features included when `patterns` is specified).
+- `combine_feature_dicts(...)`: combine multiple feature dicts.
+- Recording with `record_loop(...)` uses `build_dataset_frame(...)` to build frames consistent with `dataset.features` before we call `add_frame(...)` to add the frame to the dataset.
+
+## Guidance when customizing robot pipelines
+
+You can store any of the following features as your action/observation space:
+
+- Joint positions
+- Absolute EE poses
+- Relative EE deltas
+- Other features: joint velocity, torques, etc.
+
+Pick what you want to use for your policy action and observation space and configure/modify the pipelines and steps accordingly.
+
+### Different robots
+
+- You can easily reuse pipelines, for example to use another robot with phone teleop, modify the examples and swap the robot `RobotKinematics` (URDF) and `motor_names` to use your own robot with Phone teleop. Additionally you should ensure `target_frame_name` points to your gripper/wrist.
+
+### Safety first
+
+- When changing pipelines, start with tight bounds, implement safety steps when working with real robots.
+- Its advised to start with simulation first and then move to real robots.
+
+Thats it! We hope this guide helps you get started with customizing your robot pipelines, If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.
diff --git a/docs/source/reachy2.mdx b/docs/source/reachy2.mdx
new file mode 100644
index 000000000..7d3dc1b60
--- /dev/null
+++ b/docs/source/reachy2.mdx
@@ -0,0 +1,288 @@
+# Reachy 2
+
+Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
+Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
+
+## Teleoperate Reachy 2
+
+Currently, there are two ways to teleoperate Reachy 2:
+
+- Pollen Robotics’ VR teleoperation (not included in LeRobot).
+- Robot-to-robot teleoperation (use one Reachy 2 to control another).
+
+## Reachy 2 Simulation
+
+**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
+
+1. Install [Docker Engine](https://docs.docker.com/engine/).
+2. Run (for MuJoCo):
+
+```
+docker run --rm -it \
+ --name reachy \
+ --privileged \
+ --network host \
+ --ipc host \
+ --device-cgroup-rule='c 189:* rwm' \
+ --group-add audio \
+ -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
+ -e DISPLAY="$DISPLAY" \
+ -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
+ -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
+ -v /dev:/dev \
+ -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
+ -v "$HOME/.reachy.log":/home/reachy/.ros/log \
+ -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
+ --entrypoint /package/launch.sh \
+ pollenrobotics/reachy2_core:1.7.5.9_deploy \
+ start_rviz:=true start_sdk_server:=true mujoco:=true
+```
+
+> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
+>
+> ```
+> docker run --rm -it \
+> --name reachy \
+> --privileged \
+> --network host \
+> --ipc host \
+> --device-cgroup-rule='c 189:* rwm' \
+> --group-add audio \
+> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
+> -e DISPLAY="$DISPLAY" \
+> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
+> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
+> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
+> -v /dev:/dev \
+> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
+> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
+> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
+> --entrypoint /package/launch.sh \
+> pollenrobotics/reachy2_core:1.7.5.9_deploy \
+> start_rviz:=true start_sdk_server:=true mujoco:=true
+> ```
+
+## Setup
+
+### Prerequisites
+
+- On your robot, check the **service images** meet the minimum versions:
+ - **reachy2-core >= 1.7.5.2**
+ - **webrtc >= 2.0.1.1**
+
+Then, if you want to use VR teleoperation:
+
+- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
+ Use version **>=v1.2.0**
+
+We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
+
+### Install LeRobot
+
+Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
+
+Install LeRobot with Reachy 2 dependencies:
+
+```bash
+pip install -e ".[reachy2]"
+```
+
+### (Optional but recommended) Install pollen_data_acquisition_server
+
+How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
+
+> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs.
+
+In your LeRobot environment, install the server from source:
+
+```bash
+git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
+cd pollen_data_acquisition_server
+pip install -e .
+```
+
+Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
+
+## Step 1: Recording
+
+### Get Reachy 2 IP address
+
+Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
+We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
+
+### Launch recording
+
+There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
+
+- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions.
+- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.
+
+### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
+
+Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
+
+Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
+
+```bash
+python -m pollen_data_acquisition_server.server
+```
+
+Then get into the teleoperation application and choose "Data acquisition session".
+You can finally setup your session by following the screens displayed.
+
+> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
+
+### Option 2: Using lerobot.record
+
+Reachy 2 is fully supported by LeRobot’s recording features.
+If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
+
+**Example: start a recording without the mobile base:**
+First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
+
+```bash
+python -m lerobot.record \
+ --robot.type=reachy2 \
+ --robot.ip_address=192.168.0.200 \
+ --robot.id=r2-0000 \
+ --robot.use_external_commands=true \
+ --robot.with_mobile_base=false \
+ --teleop.type=reachy2_teleoperator \
+ --teleop.ip_address=192.168.0.200 \
+ --teleop.with_mobile_base=false \
+ --dataset.repo_id=pollen_robotics/record_test \
+ --dataset.single_task="Reachy 2 recording test" \
+ --dataset.num_episodes=1 \
+ --dataset.episode_time_s=5 \
+ --dataset.fps=15 \
+ --dataset.push_to_hub=true \
+ --dataset.private=true \
+ --display_data=true
+```
+
+#### Specific Options
+
+**Extended setup overview (all options included):**
+
+```bash
+python -m lerobot.record \
+ --robot.type=reachy2 \
+ --robot.ip_address=192.168.0.200 \
+ --robot.use_external_commands=true \
+ --robot.with_mobile_base=true \
+ --robot.with_l_arm=true \
+ --robot.with_r_arm=true \
+ --robot.with_neck=true \
+ --robot.with_antennas=true \
+ --robot.with_left_teleop_camera=true \
+ --robot.with_right_teleop_camera=true \
+ --robot.with_torso_camera=false \
+ --robot.disable_torque_on_disconnect=false \
+ --robot.max_relative_target=5.0 \
+ --teleop.type=reachy2_teleoperator \
+ --teleop.ip_address=192.168.0.200 \
+ --teleop.use_present_position=false \
+ --teleop.with_mobile_base=false \
+ --teleop.with_l_arm=true \
+ --teleop.with_r_arm=true \
+ --teleop.with_neck=true \
+ --teleop.with_antennas=true \
+ --dataset.repo_id=pollen_robotics/record_test \
+ --dataset.single_task="Reachy 2 recording test" \
+ --dataset.num_episodes=1 \
+ --dataset.episode_time_s=5 \
+ --dataset.fps=15 \
+ --dataset.push_to_hub=true \
+ --dataset.private=true \
+ --display_data=true
+```
+
+##### `--robot.use_external_commands`
+
+Determine whether LeRobot robot.send_action() sends commands to the robot.
+**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
+
+##### `--teleop.use_present_position`
+
+Determine whether the teleoperator reads the goal or present position of the robot.
+Must be set to true if a compliant Reachy 2 is used to control another one.
+
+##### Use the relevant parts
+
+From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
+To avoid this, you can exclude specific parts from recording and replay using:
+
+````
+--robot.with_=false
+```,
+with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
+It determine whether the corresponding part is recorded in the observations. True if not set.
+
+By default, **all parts are recorded**.
+
+The same per-part mechanism is available in `reachy2_teleoperator` as well.
+
+````
+
+--teleop.with\_
+
+```
+with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
+Determine whether the corresponding part is recorded in the actions. True if not set.
+
+> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
+For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
+
+##### Use the relevant cameras
+
+You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
+
+```
+
+--robot.with_left_teleop_camera=
+--robot.with_right_teleop_camera=
+--robot.with_torso_camera=
+
+````
+
+
+## Step 2: Replay
+
+Make sure the robot is configured with the same parts as the dataset:
+
+```bash
+python -m lerobot.replay \
+ --robot.type=reachy2 \
+ --robot.ip_address=192.168.0.200 \
+ --robot.use_external_commands=false \
+ --robot.with_mobile_base=false \
+ --dataset.repo_id=pollen_robotics/record_test \
+ --dataset.episode=0
+ --display_data=true
+````
+
+## Step 3: Train
+
+```bash
+python -m lerobot.scripts.train \
+ --dataset.repo_id=pollen_robotics/record_test \
+ --policy.type=act \
+ --output_dir=outputs/train/reachy2_test \
+ --job_name=reachy2 \
+ --policy.device=mps \
+ --wandb.enable=true \
+ --policy.repo_id=pollen_robotics/record_test_policy
+```
+
+## Step 4: Evaluate
+
+```bash
+python -m lerobot.record \
+ --robot.type=reachy2 \
+ --robot.ip_address=192.168.0.200 \
+ --display_data=false \
+ --dataset.repo_id=pollen_robotics/eval_record_test \
+ --dataset.single_task="Evaluate reachy2 policy" \
+ --dataset.num_episodes=10 \
+ --policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
+```
diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx
index 17a2bdf18..a56298b5e 100644
--- a/docs/source/smolvla.mdx
+++ b/docs/source/smolvla.mdx
@@ -1,11 +1,20 @@
-# Finetune SmolVLA
+# SmolVLA
SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!
-
-
- Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the robot’s current sensorimotor state, and (iii) a natural language instruction, encoded into contextual features used to condition the action expert when generating an action chunk.
+
+
+
+ Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the
+ robot’s current sensorimotor state, and (iii) a natural language
+ instruction, encoded into contextual features used to condition the action
+ expert when generating an action chunk.
+
## Set Up Your Environment
@@ -20,7 +29,7 @@ SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed
## Collect a dataset
SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup.
-We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset)
+We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](./il_robots)
@@ -32,6 +41,7 @@ We recommend checking out the dataset linked below for reference that was used i
In this dataset, we recorded 50 episodes across 5 distinct cube positions. For each position, we collected 10 episodes of pick-and-place interactions. This structure, repeating each variation several times, helped the model generalize better. We tried similar dataset with 25 episodes, and it was not enough leading to a bad performance. So, the data quality and quantity is definitely a key.
After you have your dataset available on the Hub, you are good to go to use our finetuning script to adapt SmolVLA to your application.
+
## Finetune SmolVLA on your data
@@ -44,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [.
```bash
-cd lerobot && python -m lerobot.scripts.train \
+cd lerobot && lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=${HF_USER}/mydataset \
--batch_size=64 \
@@ -56,29 +66,38 @@ cd lerobot && python -m lerobot.scripts.train \
```
-You can start with a small batch size and increase it incrementally, if the GPU allows it, as long as loading times remain short.
+ You can start with a small batch size and increase it incrementally, if the
+ GPU allows it, as long as loading times remain short.
Fine-tuning is an art. For a complete overview of the options for finetuning, run
```bash
-python -m lerobot.scripts.train --help
+lerobot-train --help
```
-
-
- Figure 2: Comparison of SmolVLA across task variations. From left to right: (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place cube counting under perturbations, and (4) generalization on pick-and-place of the lego block with real-world SO101.
+
+
+
+ Figure 2: Comparison of SmolVLA across task variations. From left to right:
+ (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place
+ cube counting under perturbations, and (4) generalization on pick-and-place
+ of the lego block with real-world SO101.
+
-
## Evaluate the finetuned model and run it in real-time
-Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset).
+Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./il_robots).
Once you are logged in, you can run inference in your setup by doing:
```bash
-python -m lerobot.record \
+lerobot-record \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id
diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx
deleted file mode 120000
index 0a71dc307..000000000
--- a/docs/source/so100.mdx
+++ /dev/null
@@ -1 +0,0 @@
-../../src/lerobot/robots/so100_follower/so100.mdx
\ No newline at end of file
diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx
new file mode 100644
index 000000000..3c73ae801
--- /dev/null
+++ b/docs/source/so100.mdx
@@ -0,0 +1,640 @@
+# SO-100
+
+In the steps below, we explain how to assemble the SO-100 robot.
+
+## Source the parts
+
+Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100.md). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. And advise if it's your first time printing or if you don't own a 3D printer.
+
+## Install LeRobot 🤗
+
+To install LeRobot, follow our [Installation Guide](./installation)
+
+In addition to these instructions, you need to install the Feetech SDK:
+
+```bash
+pip install -e ".[feetech]"
+```
+
+## Configure the motors
+
+**Note:**
+Unlike the SO-101, the motor connectors are not easily accessible once the arm is assembled, so the configuration step must be done beforehand.
+
+### 1. Find the USB ports associated with each arm
+
+To find the port for each bus servo adapter, run this script:
+
+```bash
+lerobot-find-port
+```
+
+
+
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
+Remove the USB cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
+
+
+
+
+On Linux, you might need to give access to the USB ports by running:
+
+```bash
+sudo chmod 666 /dev/ttyACM0
+sudo chmod 666 /dev/ttyACM1
+```
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/ttyACM0', '/dev/ttyACM1']
+Remove the usb cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/ttyACM1
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
+
+
+
+
+### 2. Set the motors ids and baudrates
+
+Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
+
+To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
+
+If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match.
+
+#### Follower
+
+Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
+
+For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm.
+
+
+
+
+```bash
+lerobot-setup-motors \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
+```
+
+
+
+
+
+```python
+from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
+
+config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem585A0076841",
+ id="my_awesome_follower_arm",
+)
+follower = SO100Follower(config)
+follower.setup_motors()
+```
+
+
+
+
+
+You should see the following instruction
+
+```
+Connect the controller board to the 'gripper' motor only and press enter.
+```
+
+As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
+
+
+Troubleshooting
+
+If you get an error at that point, check your cables and make sure they are plugged in properly:
+
+
+
Power supply
+
USB cable between your computer and the controller board
+
The 3-pin cable from the controller board to the motor
+
+
+If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
+
+
+
+You should then see the following message:
+
+```
+'gripper' motor id set to 6
+```
+
+Followed by the next instruction:
+
+```
+Connect the controller board to the 'wrist_roll' motor only and press enter.
+```
+
+You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
+
+Repeat the operation for each motor as instructed.
+
+> [!TIP]
+> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
+
+When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
+
+#### Leader
+
+Do the same steps for the leader arm.
+
+
+
+```bash
+lerobot-setup-motors \
+ --teleop.type=so100_leader \
+ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
+```
+
+
+
+
+```python
+from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
+
+config = SO100LeaderConfig(
+ port="/dev/tty.usbmodem585A0076841",
+ id="my_awesome_leader_arm",
+)
+leader = SO100Leader(config)
+leader.setup_motors()
+```
+
+
+
+
+
+## Step-by-Step Assembly Instructions
+
+## Remove the gears of the 6 leader motors
+
+
+Video removing gears
+
+
+
+
+
+
+
+Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
+
+### Clean Parts
+
+Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
+
+### Additional Guidance
+
+
+Video assembling arms
+
+
+
+
+
+
+
+**Note:**
+This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
+
+---
+
+### First Motor
+
+**Step 2: Insert Wires**
+
+- Insert two wires into the first motor.
+
+
+
+**Step 3: Install in Base**
+
+- Place the first motor into the base.
+
+
+
+**Step 4: Secure Motor**
+
+- Fasten the motor with 4 screws. Two from the bottom and two from top.
+
+**Step 5: Attach Motor Holder**
+
+- Slide over the first motor holder and fasten it using two screws (one on each side).
+
+
+
+**Step 6: Attach Motor Horns**
+
+- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
+
+
+
+
+
+ Video adding motor horn
+
+
+
+
+**Step 7: Attach Shoulder Part**
+
+- Route one wire to the back of the robot and the other to the left or towards you (see photo).
+- Attach the shoulder part.
+
+
+
+**Step 8: Secure Shoulder**
+
+- Tighten the shoulder part with 4 screws on top and 4 on the bottom
+ _(access bottom holes by turning the shoulder)._
+
+---
+
+### Second Motor Assembly
+
+**Step 9: Install Motor 2**
+
+- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
+
+
+
+**Step 10: Attach Shoulder Holder**
+
+- Add the shoulder motor holder.
+- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
+- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
+
+
+
+
+
+
+
+**Step 11: Secure Motor 2**
+
+- Fasten the second motor with 4 screws.
+
+**Step 12: Attach Motor Horn**
+
+- Attach both motor horns to motor 2, again use the horn screw.
+
+**Step 13: Attach Base**
+
+- Install the base attachment using 2 screws.
+
+
+
+**Step 14: Attach Upper Arm**
+
+- Attach the upper arm with 4 screws on each side.
+
+
+
+---
+
+### Third Motor Assembly
+
+**Step 15: Install Motor 3**
+
+- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
+
+**Step 16: Attach Motor Horn**
+
+- Attach both motor horns to motor 3 and secure one again with a horn screw.
+
+
+
+**Step 17: Attach Forearm**
+
+- Connect the forearm to motor 3 using 4 screws on each side.
+
+
+
+---
+
+### Fourth Motor Assembly
+
+**Step 18: Install Motor 4**
+
+- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
+
+
+
+
+
+
+**Step 19: Attach Motor Holder 4**
+
+- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
+
+
+
+**Step 20: Secure Motor 4 & Attach Horn**
+
+- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw.
+
+
+
+---
+
+### Wrist Assembly
+
+**Step 21: Install Motor 5**
+
+- Insert motor 5 into the wrist holder and secure it with 2 front screws.
+
+
+
+**Step 22: Attach Wrist**
+
+- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper.
+- Secure the wrist to motor 4 using 4 screws on both sides.
+
+
+
+**Step 23: Attach Wrist Horn**
+
+- Install only one motor horn on the wrist motor and secure it with a horn screw.
+
+
+
+---
+
+### Follower Configuration
+
+**Step 24: Attach Gripper**
+
+- Attach the gripper to motor 5.
+
+
+
+**Step 25: Install Gripper Motor**
+
+- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
+
+
+
+**Step 26: Attach Gripper Horn & Claw**
+
+- Attach the motor horns and again use a horn screw.
+- Install the gripper claw and secure it with 4 screws on both sides.
+
+
+
+**Step 27: Mount Controller**
+
+- Attach the motor controller to the back of the robot.
+
+
+
+
+
+
+_Assembly complete – proceed to Leader arm assembly._
+
+---
+
+### Leader Configuration
+
+For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors.
+
+**Step 24: Attach Leader Holder**
+
+- Mount the leader holder onto the wrist and secure it with a screw.
+
+
+
+**Step 25: Attach Handle**
+
+- Attach the handle to motor 5 using 4 screws.
+
+
+
+**Step 26: Install Gripper Motor**
+
+- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
+
+
+
+**Step 27: Attach Trigger**
+
+- Attach the follower trigger with 4 screws.
+
+
+
+**Step 28: Mount Controller**
+
+- Attach the motor controller to the back of the robot.
+
+
+
+
+
+
+## Calibrate
+
+Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
+The calibration process is very important because it allows a neural network trained on one robot to work on another.
+
+#### Follower
+
+Run the following command or API example to calibrate the follower arm:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower
+
+config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem585A0076891",
+ id="my_awesome_follower_arm",
+)
+
+follower = SO100Follower(config)
+follower.connect(calibrate=False)
+follower.calibrate()
+follower.disconnect()
+```
+
+
+
+
+
+We unified the calibration method for most robots. Thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video)
+
+#### Leader
+
+Do the same steps to calibrate the leader arm, run the following command or API example:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --teleop.type=so100_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
+
+config = SO100LeaderConfig(
+ port="/dev/tty.usbmodem58760431551",
+ id="my_awesome_leader_arm",
+)
+
+leader = SO100Leader(config)
+leader.connect(calibrate=False)
+leader.calibrate()
+leader.disconnect()
+```
+
+
+
+
+
+Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
+
+> [!TIP]
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx
deleted file mode 120000
index ab6d0ac61..000000000
--- a/docs/source/so101.mdx
+++ /dev/null
@@ -1 +0,0 @@
-../../src/lerobot/robots/so101_follower/so101.mdx
\ No newline at end of file
diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx
new file mode 100644
index 000000000..00ec3eb74
--- /dev/null
+++ b/docs/source/so101.mdx
@@ -0,0 +1,436 @@
+# SO-101
+
+In the steps below, we explain how to assemble our flagship robot, the SO-101.
+
+## Source the parts
+
+Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts.
+And advise if it's your first time printing or if you don't own a 3D printer.
+
+## Install LeRobot 🤗
+
+To install LeRobot, follow our [Installation Guide](./installation)
+
+In addition to these instructions, you need to install the Feetech SDK:
+
+```bash
+pip install -e ".[feetech]"
+```
+
+## Step-by-Step Assembly Instructions
+
+The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below.
+
+| Leader-Arm Axis | Motor | Gear Ratio |
+| ------------------- | :---: | :--------: |
+| Base / Shoulder Pan | 1 | 1 / 191 |
+| Shoulder Lift | 2 | 1 / 345 |
+| Elbow Flex | 3 | 1 / 191 |
+| Wrist Flex | 4 | 1 / 147 |
+| Wrist Roll | 5 | 1 / 147 |
+| Gripper | 6 | 1 / 147 |
+
+### Clean Parts
+
+Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
+
+It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
+
+### Joint 1
+
+- Place the first motor into the base.
+- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
+- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
+- Install both motor horns, securing the top horn with a M3x6mm screw.
+- Attach the shoulder part.
+- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
+- Add the shoulder motor holder.
+
+
+
+
+
+### Joint 2
+
+- Slide the second motor in from the top.
+- Fasten the second motor with 4 M2x6mm screws.
+- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
+- Attach the upper arm with 4 M3x6mm screws on each side.
+
+
+
+
+
+### Joint 3
+
+- Insert motor 3 and fasten using 4 M2x6mm screws
+- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
+- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
+
+
+
+
+
+### Joint 4
+
+- Slide over motor holder 4.
+- Slide in motor 4.
+- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
+
+
+
+
+
+### Joint 5
+
+- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
+- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
+- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
+
+
+
+
+
+### Gripper / Handle
+
+
+
+
+- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
+- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
+- Attach the motor horns and again use a M3x6mm horn screw.
+- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
+
+
+
+
+
+
+
+
+- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
+- Attach the handle to motor 5 using 1 M2x6mm screw.
+- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
+- Attach the follower trigger with 4 M3x6mm screws.
+
+
+
+
+
+
+
+
+## Configure the motors
+
+### 1. Find the USB ports associated with each arm
+
+To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted:
+
+```bash
+lerobot-find-port
+```
+
+
+
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
+Remove the USB cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/tty.usbmodem575E0032081
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
+
+
+
+
+On Linux, you might need to give access to the USB ports by running:
+
+```bash
+sudo chmod 666 /dev/ttyACM0
+sudo chmod 666 /dev/ttyACM1
+```
+
+Example output:
+
+```
+Finding all available ports for the MotorBus.
+['/dev/ttyACM0', '/dev/ttyACM1']
+Remove the usb cable from your MotorsBus and press Enter when done.
+
+[...Disconnect corresponding leader or follower arm and press Enter...]
+
+The port of this MotorsBus is /dev/ttyACM1
+Reconnect the USB cable.
+```
+
+Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
+
+
+
+
+### 2. Set the motors ids and baudrates
+
+Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
+
+To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
+
+If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match.
+
+The video below shows the sequence of steps for setting the motor ids.
+
+##### Setup motors video
+
+
+
+
+
+#### Follower
+
+Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
+
+
+
+
+```bash
+lerobot-setup-motors \
+ --robot.type=so101_follower \
+ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
+```
+
+
+
+
+
+```python
+from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
+
+config = SO101FollowerConfig(
+ port="/dev/tty.usbmodem585A0076841",
+ id="my_awesome_follower_arm",
+)
+follower = SO101Follower(config)
+follower.setup_motors()
+```
+
+
+
+
+
+You should see the following instruction
+
+```bash
+Connect the controller board to the 'gripper' motor only and press enter.
+```
+
+As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
+
+
+Troubleshooting
+
+If you get an error at that point, check your cables and make sure they are plugged in properly:
+
+
+
Power supply
+
USB cable between your computer and the controller board
+
The 3-pin cable from the controller board to the motor
+
+
+If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
+
+
+
+You should then see the following message:
+
+```bash
+'gripper' motor id set to 6
+```
+
+Followed by the next instruction:
+
+```bash
+Connect the controller board to the 'wrist_roll' motor only and press enter.
+```
+
+You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
+
+Repeat the operation for each motor as instructed.
+
+> [!TIP]
+> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
+
+When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
+
+#### Leader
+
+Do the same steps for the leader arm.
+
+
+
+
+```bash
+lerobot-setup-motors \
+ --teleop.type=so101_leader \
+ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig
+
+config = SO101LeaderConfig(
+ port="/dev/tty.usbmodem585A0076841",
+ id="my_awesome_leader_arm",
+)
+leader = SO101Leader(config)
+leader.setup_motors()
+```
+
+
+
+
+
+## Calibrate
+
+Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
+The calibration process is very important because it allows a neural network trained on one robot to work on another.
+
+#### Follower
+
+Run the following command or API example to calibrate the follower arm:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --robot.type=so101_follower \
+ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
+
+config = SO101FollowerConfig(
+ port="/dev/tty.usbmodem585A0076891",
+ id="my_awesome_follower_arm",
+)
+
+follower = SO101Follower(config)
+follower.connect(calibrate=False)
+follower.calibrate()
+follower.disconnect()
+```
+
+
+
+
+
+The video below shows how to perform the calibration. First you need to move the robot to the position where all joints are in the middle of their ranges. Then after pressing enter you have to move each joint through its full range of motion.
+
+##### Calibration video
+
+
+
+
+
+#### Leader
+
+Do the same steps to calibrate the leader arm, run the following command or API example:
+
+
+
+
+```bash
+lerobot-calibrate \
+ --teleop.type=so101_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
+ --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
+```
+
+
+
+
+
+```python
+from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
+
+config = SO101LeaderConfig(
+ port="/dev/tty.usbmodem58760431551",
+ id="my_awesome_leader_arm",
+)
+
+leader = SO101Leader(config)
+leader.connect(calibrate=False)
+leader.calibrate()
+leader.disconnect()
+```
+
+
+
+
+
+Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
+
+> [!TIP]
+> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
deleted file mode 100644
index c0c7845e8..000000000
--- a/examples/2_evaluate_pretrained_policy.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
-training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
-
-It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
-```bash
-pip install -e ".[pusht]"
-```
-"""
-
-from pathlib import Path
-
-import gym_pusht # noqa: F401
-import gymnasium as gym
-import imageio
-import numpy
-import torch
-
-from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
-
-# Create a directory to store the video of the evaluation
-output_directory = Path("outputs/eval/example_pusht_diffusion")
-output_directory.mkdir(parents=True, exist_ok=True)
-
-# Select your device
-device = "cuda"
-
-# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
-pretrained_policy_path = "lerobot/diffusion_pusht"
-# OR a path to a local outputs/train folder.
-# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-
-policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
-
-# Initialize evaluation environment to render two observation types:
-# an image of the scene and state/position of the agent. The environment
-# also automatically stops running after 300 interactions/steps.
-env = gym.make(
- "gym_pusht/PushT-v0",
- obs_type="pixels_agent_pos",
- max_episode_steps=300,
-)
-
-# We can verify that the shapes of the features expected by the policy match the ones from the observations
-# produced by the environment
-print(policy.config.input_features)
-print(env.observation_space)
-
-# Similarly, we can check that the actions produced by the policy will match the actions expected by the
-# environment
-print(policy.config.output_features)
-print(env.action_space)
-
-# Reset the policy and environments to prepare for rollout
-policy.reset()
-numpy_observation, info = env.reset(seed=42)
-
-# Prepare to collect every rewards and all the frames of the episode,
-# from initial state to final state.
-rewards = []
-frames = []
-
-# Render frame of the initial state
-frames.append(env.render())
-
-step = 0
-done = False
-while not done:
- # Prepare observation for the policy running in Pytorch
- state = torch.from_numpy(numpy_observation["agent_pos"])
- image = torch.from_numpy(numpy_observation["pixels"])
-
- # Convert to float32 with image from channel first in [0,255]
- # to channel last in [0,1]
- state = state.to(torch.float32)
- image = image.to(torch.float32) / 255
- image = image.permute(2, 0, 1)
-
- # Send data tensors from CPU to GPU
- state = state.to(device, non_blocking=True)
- image = image.to(device, non_blocking=True)
-
- # Add extra (empty) batch dimension, required to forward the policy
- state = state.unsqueeze(0)
- image = image.unsqueeze(0)
-
- # Create the policy input dictionary
- observation = {
- "observation.state": state,
- "observation.image": image,
- }
-
- # Predict the next action with respect to the current observation
- with torch.inference_mode():
- action = policy.select_action(observation)
-
- # Prepare the action for the environment
- numpy_action = action.squeeze(0).to("cpu").numpy()
-
- # Step through the environment and receive a new observation
- numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
- print(f"{step=} {reward=} {terminated=}")
-
- # Keep track of all the rewards and frames
- rewards.append(reward)
- frames.append(env.render())
-
- # The rollout is considered done when the success state is reached (i.e. terminated is True),
- # or the maximum number of iterations is reached (i.e. truncated is True)
- done = terminated | truncated | done
- step += 1
-
-if terminated:
- print("Success!")
-else:
- print("Failure!")
-
-# Get the speed of environment (i.e. its number of frames per second).
-fps = env.metadata["render_fps"]
-
-# Encode all frames into a mp4 video.
-video_path = output_directory / "rollout.mp4"
-imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
-
-print(f"Video of the evaluation is available in '{video_path}'.")
diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md
deleted file mode 100644
index f17411b75..000000000
--- a/examples/4_train_policy_with_script.md
+++ /dev/null
@@ -1,274 +0,0 @@
-This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
-> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
-
-
-## The training script
-
-LeRobot offers a training script at [`lerobot/scripts/train.py`](../src/lerobot/scripts/train.py). At a high level it does the following:
-
-- Initialize/load a configuration for the following steps using.
-- Instantiates a dataset.
-- (Optional) Instantiates a simulation environment corresponding to that dataset.
-- Instantiates a policy.
-- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing.
-
-## Overview of the configuration system
-
-In the training script, the main function `train` expects a `TrainPipelineConfig` object:
-```python
-# train.py
-@parser.wrap()
-def train(cfg: TrainPipelineConfig):
-```
-
-You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../src/lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
-
-When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
-
-Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes:
-```python
-@dataclass
-class TrainPipelineConfig:
- dataset: DatasetConfig
- env: envs.EnvConfig | None = None
- policy: PreTrainedConfig | None = None
-```
-in which `DatasetConfig` for example is defined as such:
-```python
-@dataclass
-class DatasetConfig:
- repo_id: str
- episodes: list[int] | None = None
- video_backend: str = "pyav"
-```
-
-This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
-From the command line, we can specify this value by using a very similar syntax `--dataset.repo_id=repo/id`.
-
-By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified.
-
-
-## Specifying values from the CLI
-
-Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
-```bash
-python -m lerobot.scripts.train \
- --dataset.repo_id=lerobot/pusht \
- --policy.type=diffusion \
- --env.type=pusht
-```
-
-Let's break this down:
-- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
-- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/policies](../src/lerobot/policies)
-- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/envs/configs.py`](../src/lerobot/envs/configs.py)
-
-Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
-```bash
-python -m lerobot.scripts.train \
- --policy.type=act \
- --dataset.repo_id=lerobot/aloha_sim_insertion_human \
- --env.type=aloha \
- --output_dir=outputs/train/act_aloha_insertion
-```
-> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
-
-We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
-Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
-```bash
-python -m lerobot.scripts.train \
- --policy.type=act \
- --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
- --env.type=aloha \
- --env.task=AlohaTransferCube-v0 \
- --output_dir=outputs/train/act_aloha_transfer
-```
-
-## Loading from a config file
-
-Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used:
-```json
-{
- "dataset": {
- "repo_id": "lerobot/aloha_sim_transfer_cube_human",
- "episodes": null,
- ...
- },
- "env": {
- "type": "aloha",
- "task": "AlohaTransferCube-v0",
- "fps": 50,
- ...
- },
- "policy": {
- "type": "act",
- "n_obs_steps": 1,
- ...
- },
- ...
-}
-```
-
-We can then simply load the config values from this file using:
-```bash
-python -m lerobot.scripts.train \
- --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
- --output_dir=outputs/train/act_aloha_transfer_2
-```
-`--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly.
-
-Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.:
-```bash
-python -m lerobot.scripts.train \
- --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \
- --output_dir=outputs/train/act_aloha_transfer_2
- --policy.n_action_steps=80
-```
-> Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next.
-
-`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running:
-```bash
-python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht
-```
-will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)
-
-
-## Resume training
-
-Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to do that here.
-
-Let's reuse the command from the previous run and add a few more options:
-```bash
-python -m lerobot.scripts.train \
- --policy.type=act \
- --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
- --env.type=aloha \
- --env.task=AlohaTransferCube-v0 \
- --log_freq=25 \
- --save_freq=100 \
- --output_dir=outputs/train/run_resumption
-```
-
-Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal:
-```
-INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100
-```
-Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with:
-```bash
-python -m lerobot.scripts.train \
- --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
- --resume=true
-```
-You should see from the logging that your training picks up from where it left off.
-
-Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
-You could double the number of steps of the previous run with:
-```bash
-python -m lerobot.scripts.train \
- --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
- --resume=true \
- --steps=200000
-```
-
-## Outputs of a run
-In the output directory, there will be a folder called `checkpoints` with the following structure:
-```bash
-outputs/train/run_resumption/checkpoints
-├── 000100 # checkpoint_dir for training step 100
-│ ├── pretrained_model/
-│ │ ├── config.json # policy config
-│ │ ├── model.safetensors # policy weights
-│ │ └── train_config.json # train config
-│ └── training_state/
-│ ├── optimizer_param_groups.json # optimizer param groups
-│ ├── optimizer_state.safetensors # optimizer state
-│ ├── rng_state.safetensors # rng states
-│ ├── scheduler_state.json # scheduler state
-│ └── training_step.json # training step
-├── 000200
-└── last -> 000200 # symlink to the last available checkpoint
-```
-
-## Fine-tuning a pre-trained policy
-
-In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub.
-
-For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with:
-```bash
-python -m lerobot.scripts.train \
- --policy.path=lerobot/act_aloha_sim_transfer_cube_human \
- --dataset.repo_id=lerobot/aloha_sim_insertion_human \
- --env.type=aloha \
- --env.task=AlohaInsertion-v0
-```
-
-When doing so, keep in mind that the features of the fine-tuning dataset would have to match the input/output features of the pretrained policy.
-
-## Typical logs and metrics
-
-When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint.
-
-After that, you will see training log like this one:
-```
-INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
-```
-or evaluation log:
-```
-INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
-```
-
-These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations:
-- `smpl`: number of samples seen during training.
-- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task.
-- `epch`: number of time all unique samples are seen (epoch).
-- `grdn`: gradient norm.
-- `∑rwrd`: compute the sum of rewards in every evaluation episode and then take an average of them.
-- `success`: average success rate of eval episodes. Reward and success are usually different except for the sparsing reward setting, where reward=1 only when the task is completed successfully.
-- `eval_s`: time to evaluate the policy in the environment, in second.
-- `updt_s`: time to update the network parameters, in second.
-- `data_s`: time to load a batch of data, in second.
-
-Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
-
-## In short
-
-We'll summarize here the main use cases to remember from this tutorial.
-
-#### Train a policy from scratch – CLI
-```bash
-python -m lerobot.scripts.train \
- --policy.type=act \ # <- select 'act' policy
- --env.type=pusht \ # <- select 'pusht' environment
- --dataset.repo_id=lerobot/pusht # <- train on this dataset
-```
-
-#### Train a policy from scratch - config file + CLI
-```bash
-python -m lerobot.scripts.train \
- --config_path=path/to/pretrained_model \ # <- can also be a repo_id
- --policy.n_action_steps=80 # <- you may still override values
-```
-
-#### Resume/continue a training run
-```bash
-python -m lerobot.scripts.train \
- --config_path=checkpoint/pretrained_model/ \
- --resume=true \
- --steps=200000 # <- you can change some training parameters
-```
-
-#### Fine-tuning
-```bash
-python -m lerobot.scripts.train \
- --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint
- --dataset.repo_id=lerobot/aloha_sim_insertion_human \
- --env.type=aloha \
- --env.task=AlohaInsertion-v0
-```
-
----
-
-Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to actual robots, or how to record your own datasets and train policies on your specific task?
-If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md).
-
-Or in the meantime, happy training! 🤗
diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py
deleted file mode 100644
index 3760feabb..000000000
--- a/examples/advanced/1_add_image_transforms.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
-augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
-transforms are applied to the observation images before they are returned in the dataset's __getitem__.
-"""
-
-from pathlib import Path
-
-from torchvision.transforms import ToPILImage, v2
-
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
-dataset_repo_id = "lerobot/aloha_static_screw_driver"
-
-# Create a LeRobotDataset with no transformations
-dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
-# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
-
-# Get the index of the first observation in the first episode
-first_idx = dataset.episode_data_index["from"][0].item()
-
-# Get the frame corresponding to the first camera
-frame = dataset[first_idx][dataset.meta.camera_keys[0]]
-
-
-# Define the transformations
-transforms = v2.Compose(
- [
- v2.ColorJitter(brightness=(0.5, 1.5)),
- v2.ColorJitter(contrast=(0.5, 1.5)),
- v2.ColorJitter(hue=(-0.1, 0.1)),
- v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
- ]
-)
-
-# Create another LeRobotDataset with the defined transformations
-transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
-
-# Get a frame from the transformed dataset
-transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
-
-# Create a directory to store output images
-output_dir = Path("outputs/image_transforms")
-output_dir.mkdir(parents=True, exist_ok=True)
-
-# Save the original frame
-to_pil = ToPILImage()
-to_pil(frame).save(output_dir / "original_frame.png", quality=100)
-print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
-
-# Save the transformed frame
-to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
-print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py
deleted file mode 100644
index 9eeb1a2d9..000000000
--- a/examples/advanced/2_calculate_validation_loss.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data.
-
-This technique can be useful for debugging and testing purposes, as well as identifying whether a policy
-is learning effectively.
-
-Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice,
-especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly
-on the target environment, whether that be in simulation or the real world.
-"""
-
-import math
-
-import torch
-
-from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
-from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
-
-
-def main():
- device = torch.device("cuda")
-
- # Download the diffusion policy for pusht environment
- pretrained_policy_path = "lerobot/diffusion_pusht"
- # OR uncomment the following to evaluate a policy from the local outputs/train folder.
- # pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
-
- policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
- policy.eval()
- policy.to(device)
-
- # Set up the dataset.
- delta_timestamps = {
- # Load the previous image and state at -0.1 seconds before current frame,
- # then load current image and state corresponding to 0.0 second.
- "observation.image": [-0.1, 0.0],
- "observation.state": [-0.1, 0.0],
- # Load the previous action (-0.1), the next action to be executed (0.0),
- # and 14 future actions with a 0.1 seconds spacing. All these actions will be
- # used to calculate the loss.
- "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
- }
-
- # Load the last 10% of episodes of the dataset as a validation set.
- # - Load dataset metadata
- dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
- # - Calculate train and val episodes
- total_episodes = dataset_metadata.total_episodes
- episodes = list(range(dataset_metadata.total_episodes))
- num_train_episodes = math.floor(total_episodes * 90 / 100)
- train_episodes = episodes[:num_train_episodes]
- val_episodes = episodes[num_train_episodes:]
- print(f"Number of episodes in full dataset: {total_episodes}")
- print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
- print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
- # - Load train and val datasets
- train_dataset = LeRobotDataset(
- "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
- )
- val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
- print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
- print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
-
- # Create dataloader for evaluation.
- val_dataloader = torch.utils.data.DataLoader(
- val_dataset,
- num_workers=4,
- batch_size=64,
- shuffle=False,
- pin_memory=device != torch.device("cpu"),
- drop_last=False,
- )
-
- # Run validation loop.
- loss_cumsum = 0
- n_examples_evaluated = 0
- for batch in val_dataloader:
- batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
- loss, _ = policy.forward(batch)
-
- loss_cumsum += loss.item()
- n_examples_evaluated += batch["index"].shape[0]
-
- # Calculate the average loss over the validation set.
- average_loss = loss_cumsum / n_examples_evaluated
-
- print(f"Average loss on validation set: {average_loss:.4f}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py
index cc3397543..6bca0570f 100644
--- a/examples/backward_compatibility/replay.py
+++ b/examples/backward_compatibility/replay.py
@@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot.
Example:
```shell
-python -m lerobot.replay \
+lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
@@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401
so100_follower,
so101_follower,
)
+from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
- actions = dataset.hf_dataset.select_columns("action")
+ actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
- action_array = actions[idx]["action"]
+ action_array = actions[idx][ACTION]
action = {}
- for i, name in enumerate(dataset.features["action"]["names"]):
+ for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
diff --git a/examples/1_load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py
similarity index 96%
rename from examples/1_load_lerobot_dataset.py
rename to examples/dataset/load_lerobot_dataset.py
index 3d357dd19..a96c170cf 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/dataset/load_lerobot_dataset.py
@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
-# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
+# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
-from_idx = dataset.episode_data_index["from"][episode_index].item()
-to_idx = dataset.episode_data_index["to"][episode_index].item()
+from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
+to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
@@ -136,7 +136,7 @@ print(f"{dataset[0]['action'].shape=}\n") # (64, c)
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
- num_workers=0,
+ num_workers=4,
batch_size=32,
shuffle=True,
)
diff --git a/examples/dataset/use_dataset_image_transforms.py b/examples/dataset/use_dataset_image_transforms.py
new file mode 100644
index 000000000..c28f2ef0c
--- /dev/null
+++ b/examples/dataset/use_dataset_image_transforms.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example demonstrates how to use image transforms with LeRobot datasets for data augmentation during training.
+
+Image transforms are applied to camera frames to improve model robustness and generalization. They are applied
+at training time only, not during dataset recording, allowing you to experiment with different augmentations
+without re-recording data.
+"""
+
+import torch
+from torchvision.transforms import v2
+from torchvision.transforms.functional import to_pil_image
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.transforms import ImageTransformConfig, ImageTransforms, ImageTransformsConfig
+
+
+def save_image(tensor, filename):
+ """Helper function to save a tensor as an image file."""
+ if tensor.dim() == 3: # [C, H, W]
+ if tensor.max() > 1.0:
+ tensor = tensor / 255.0
+ tensor = torch.clamp(tensor, 0.0, 1.0)
+ pil_image = to_pil_image(tensor)
+ pil_image.save(filename)
+ print(f"Saved: {filename}")
+ else:
+ print(f"Skipped {filename}: unexpected tensor shape {tensor.shape}")
+
+
+def example_1_default_transforms():
+ """Example 1: Use default transform configuration and save original vs transformed images"""
+ print("\n Example 1: Default Transform Configuration with Image Saving")
+
+ repo_id = "pepijn223/record_main_0" # Example dataset
+
+ try:
+ # Load dataset without transforms (original)
+ dataset_original = LeRobotDataset(repo_id=repo_id)
+
+ # Load dataset with transforms enabled
+ transforms_config = ImageTransformsConfig(
+ enable=True, # Enable transforms (disabled by default)
+ max_num_transforms=2, # Apply up to 2 transforms per frame
+ random_order=False, # Apply in standard order
+ )
+ dataset_with_transforms = LeRobotDataset(
+ repo_id=repo_id, image_transforms=ImageTransforms(transforms_config)
+ )
+
+ # Save original and transformed images for comparison
+ if len(dataset_original) > 0:
+ frame_idx = 0 # Use first frame
+ original_sample = dataset_original[frame_idx]
+ transformed_sample = dataset_with_transforms[frame_idx]
+
+ print(f"Saving comparison images (frame {frame_idx}):")
+
+ for cam_key in dataset_original.meta.camera_keys:
+ if cam_key in original_sample and cam_key in transformed_sample:
+ cam_name = cam_key.replace(".", "_").replace("/", "_")
+
+ # Save original and transformed images
+ save_image(original_sample[cam_key], f"{cam_name}_original.png")
+ save_image(transformed_sample[cam_key], f"{cam_name}_transformed.png")
+
+ except Exception as e:
+ print(f"Could not load dataset '{repo_id}': {e}")
+
+
+def example_2_custom_transforms():
+ """Example 2: Create custom transform configuration and save examples"""
+ print("\n Example 2: Custom Transform Configuration")
+
+ repo_id = "pepijn223/record_main_0" # Example dataset
+
+ try:
+ # Create custom transform configuration with strong effects
+ custom_transforms_config = ImageTransformsConfig(
+ enable=True,
+ max_num_transforms=2, # Apply up to 2 transforms per frame
+ random_order=True, # Apply transforms in random order
+ tfs={
+ "brightness": ImageTransformConfig(
+ weight=1.0,
+ type="ColorJitter",
+ kwargs={"brightness": (0.5, 1.5)}, # Strong brightness range
+ ),
+ "contrast": ImageTransformConfig(
+ weight=1.0, # Higher weight = more likely to be selected
+ type="ColorJitter",
+ kwargs={"contrast": (0.6, 1.4)}, # Strong contrast
+ ),
+ "sharpness": ImageTransformConfig(
+ weight=0.5, # Lower weight = less likely to be selected
+ type="SharpnessJitter",
+ kwargs={"sharpness": (0.2, 2.0)}, # Strong sharpness variation
+ ),
+ },
+ )
+
+ dataset_with_custom_transforms = LeRobotDataset(
+ repo_id=repo_id, image_transforms=ImageTransforms(custom_transforms_config)
+ )
+
+ # Save examples with strong transforms
+ if len(dataset_with_custom_transforms) > 0:
+ sample = dataset_with_custom_transforms[0]
+ print("Saving custom transform examples:")
+
+ for cam_key in dataset_with_custom_transforms.meta.camera_keys:
+ if cam_key in sample:
+ cam_name = cam_key.replace(".", "_").replace("/", "_")
+ save_image(sample[cam_key], f"{cam_name}_custom_transforms.png")
+
+ except Exception as e:
+ print(f"Could not load dataset '{repo_id}': {e}")
+
+
+def example_3_torchvision_transforms():
+ """Example 3: Use pure torchvision transforms and save examples"""
+ print("\n Example 3: Pure Torchvision Transforms")
+
+ repo_id = "pepijn223/record_main_0" # Example dataset
+
+ try:
+ # Create torchvision transform pipeline
+ torchvision_transforms = v2.Compose(
+ [
+ v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
+ v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
+ v2.RandomRotation(degrees=10), # Small rotation
+ ]
+ )
+
+ dataset_with_torchvision = LeRobotDataset(repo_id=repo_id, image_transforms=torchvision_transforms)
+
+ # Save examples with torchvision transforms
+ if len(dataset_with_torchvision) > 0:
+ sample = dataset_with_torchvision[0]
+ print("Saving torchvision transform examples:")
+
+ for cam_key in dataset_with_torchvision.meta.camera_keys:
+ if cam_key in sample:
+ cam_name = cam_key.replace(".", "_").replace("/", "_")
+ save_image(sample[cam_key], f"{cam_name}_torchvision.png")
+
+ except Exception as e:
+ print(f"Could not load dataset '{repo_id}': {e}")
+
+
+def main():
+ """Run all examples"""
+ print("LeRobot Dataset Image Transforms Examples")
+
+ example_1_default_transforms()
+ example_2_custom_transforms()
+ example_3_torchvision_transforms()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py
index 57fb62e10..8a62d92a9 100644
--- a/examples/lekiwi/evaluate.py
+++ b/examples/lekiwi/evaluate.py
@@ -1,31 +1,54 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
-from lerobot.record import record_loop
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
+from lerobot.scripts.lerobot_record import record_loop
+from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
-from lerobot.utils.visualization_utils import _init_rerun
+from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
+HF_MODEL_ID = "/"
+HF_DATASET_ID = "/"
-# Create the robot and teleoperator configurations
+# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
+
robot = LeKiwiClient(robot_config)
-policy = ACTPolicy.from_pretrained("/")
+# Create policy
+policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
-action_features = hw_to_dataset_features(robot.action_features, "action")
-obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+action_features = hw_to_dataset_features(robot.action_features, ACTION)
+obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
- repo_id="/",
+ repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -33,33 +56,52 @@ dataset = LeRobotDataset.create(
image_writer_threads=4,
)
+# Build Policy Processors
+preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=policy,
+ pretrained_path=HF_MODEL_ID,
+ dataset_stats=dataset.meta.stats,
+ # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
+ preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+)
+
+# Connect the robot
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
-_init_rerun(session_name="recording")
+# TODO(Steven): Update this example to use pipelines
+teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
+# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
+init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
+print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
- # Run the policy inference loop
+ # Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
+ preprocessor=preprocessor, # Pass the pre and post policy processors
+ postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
)
- # Logic for reset env
+ # Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
@@ -71,6 +113,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
@@ -80,11 +125,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
dataset.clear_episode_buffer()
continue
+ # Save episode
dataset.save_episode()
recorded_episodes += 1
-# Upload to hub and clean up
-dataset.push_to_hub()
-
+# Clean up
+log_say("Stop recording")
robot.disconnect()
listener.stop()
+dataset.push_to_hub()
diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py
index 11a716761..9070741bf 100644
--- a/examples/lekiwi/record.py
+++ b/examples/lekiwi/record.py
@@ -1,37 +1,60 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
-from lerobot.record import record_loop
+from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
+from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
+from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
-from lerobot.utils.visualization_utils import _init_rerun
+from lerobot.utils.visualization_utils import init_rerun
-NUM_EPISODES = 3
+NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 30
RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
+HF_REPO_ID = "/"
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig()
+# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
+# TODO(Steven): Update this example to use pipelines
+teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
+
# Configure the dataset features
-action_features = hw_to_dataset_features(robot.action_features, "action")
-obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+action_features = hw_to_dataset_features(robot.action_features, ACTION)
+obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
- repo_id="/",
+ repo_id=HF_REPO_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -39,23 +62,25 @@ dataset = LeRobotDataset.create(
image_writer_threads=4,
)
+# Connect the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
-_init_rerun(session_name="lekiwi_record")
-
+# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
+init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
- raise ValueError("Robot, leader arm of keyboard is not connected!")
+ raise ValueError("Robot or teleop is not connected!")
+print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
- # Run the record loop
+ # Main record loop
record_loop(
robot=robot,
events=events,
@@ -65,9 +90,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
)
- # Logic for reset env
+ # Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
@@ -80,6 +108,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
@@ -89,13 +120,14 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
dataset.clear_episode_buffer()
continue
+ # Save episode
dataset.save_episode()
recorded_episodes += 1
-# Upload to hub and clean up
-dataset.push_to_hub()
-
+# Clean up
+log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
+dataset.push_to_hub()
diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py
index 248354df9..3ae915286 100644
--- a/examples/lekiwi/replay.py
+++ b/examples/lekiwi/replay.py
@@ -1,32 +1,60 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
+from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
+# Initialize the robot config
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
+
+# Initialize the robot
robot = LeKiwiClient(robot_config)
+# Fetch the dataset to replay
dataset = LeRobotDataset("/", episodes=[EPISODE_IDX])
-actions = dataset.hf_dataset.select_columns("action")
+# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
+episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
+actions = episode_frames.select_columns(ACTION)
+# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
+print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
-for idx in range(dataset.num_frames):
+for idx in range(len(episode_frames)):
t0 = time.perf_counter()
+ # Get recorded action from dataset
action = {
- name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
+ name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
- robot.send_action(action)
+
+ # Send action to robot
+ _ = robot.send_action(action)
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py
index 8358a2b93..6b430df48 100644
--- a/examples/lekiwi/teleoperate.py
+++ b/examples/lekiwi/teleoperate.py
@@ -1,10 +1,26 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import time
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import busy_wait
-from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -13,35 +29,44 @@ robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
+# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(teleop_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
+# Connect to the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
-_init_rerun(session_name="lekiwi_teleop")
+# Init rerun viewer
+init_rerun(session_name="lekiwi_teleop")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
- raise ValueError("Robot, leader arm of keyboard is not connected!")
+ raise ValueError("Robot or teleop is not connected!")
+print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
+ # Get robot observation
observation = robot.get_observation()
+ # Get teleop action
+ # Arm
arm_action = leader_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
-
+ # Keyboard
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
- log_rerun_data(observation, {**arm_action, **base_action})
-
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
- robot.send_action(action)
+ # Send action to robot
+ _ = robot.send_action(action)
+
+ # Visualize
+ log_rerun_data(observation=observation, action=action)
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py
new file mode 100644
index 000000000..0d53f1177
--- /dev/null
+++ b/examples/phone_to_so100/evaluate.py
@@ -0,0 +1,197 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.configs.types import FeatureType, PolicyFeature
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from lerobot.datasets.utils import combine_feature_dicts
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.processor import (
+ RobotAction,
+ RobotObservation,
+ RobotProcessorPipeline,
+ make_default_teleop_action_processor,
+)
+from lerobot.processor.converters import (
+ observation_to_transition,
+ robot_action_observation_to_transition,
+ transition_to_observation,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ ForwardKinematicsJointsToEE,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.scripts.lerobot_record import record_loop
+from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.utils.utils import log_say
+from lerobot.utils.visualization_utils import init_rerun
+
+NUM_EPISODES = 5
+FPS = 30
+EPISODE_TIME_SEC = 60
+TASK_DESCRIPTION = "My task description"
+HF_MODEL_ID = "/"
+HF_DATASET_ID = "/"
+
+# Create the robot configuration & robot
+camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem58760434471",
+ id="my_awesome_follower_arm",
+ cameras=camera_config,
+ use_degrees=True,
+)
+
+robot = SO100Follower(robot_config)
+
+# Create policy
+policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert EE action to joints action
+robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Build pipeline to convert joints observation to EE observation
+robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
+ steps=[
+ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
+ ],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+)
+
+# Create the dataset
+dataset = LeRobotDataset.create(
+ repo_id=HF_DATASET_ID,
+ fps=FPS,
+ features=combine_feature_dicts(
+ aggregate_pipeline_dataset_features(
+ pipeline=robot_joints_to_ee_pose_processor,
+ initial_features=create_initial_features(observation=robot.observation_features),
+ use_videos=True,
+ ),
+ # User for now should be explicit on the feature keys that were used for record
+ # Alternatively, the user can pass the processor step that has the right features
+ aggregate_pipeline_dataset_features(
+ pipeline=make_default_teleop_action_processor(),
+ initial_features=create_initial_features(
+ action={
+ f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
+ for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
+ }
+ ),
+ use_videos=True,
+ ),
+ ),
+ robot_type=robot.name,
+ use_videos=True,
+ image_writer_threads=4,
+)
+
+# Build Policy Processors
+preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=policy,
+ pretrained_path=HF_MODEL_ID,
+ dataset_stats=dataset.meta.stats,
+ # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
+ preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+)
+
+# Connect the robot
+robot.connect()
+
+# Initialize the keyboard listener and rerun visualization
+listener, events = init_keyboard_listener()
+init_rerun(session_name="phone_so100_evaluate")
+
+if not robot.is_connected:
+ raise ValueError("Robot is not connected!")
+
+print("Starting evaluate loop...")
+episode_idx = 0
+for episode_idx in range(NUM_EPISODES):
+ log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
+
+ # Main record loop
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ policy=policy,
+ preprocessor=preprocessor, # Pass the pre and post policy processors
+ postprocessor=postprocessor,
+ dataset=dataset,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=make_default_teleop_action_processor(),
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose_processor,
+ )
+
+ # Reset the environment if not stopping or re-recording
+ if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
+ log_say("Reset the environment")
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=make_default_teleop_action_processor(),
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose_processor,
+ )
+
+ if events["rerecord_episode"]:
+ log_say("Re-record episode")
+ events["rerecord_episode"] = False
+ events["exit_early"] = False
+ dataset.clear_episode_buffer()
+ continue
+
+ # Save episode
+ dataset.save_episode()
+ episode_idx += 1
+
+# Clean up
+log_say("Stop recording")
+robot.disconnect()
+listener.stop()
+dataset.push_to_hub()
diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py
new file mode 100644
index 000000000..d3ef293a7
--- /dev/null
+++ b/examples/phone_to_so100/record.py
@@ -0,0 +1,203 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from lerobot.datasets.utils import combine_feature_dicts
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ observation_to_transition,
+ robot_action_observation_to_transition,
+ transition_to_observation,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ EEBoundsAndSafety,
+ EEReferenceAndDelta,
+ ForwardKinematicsJointsToEE,
+ GripperVelocityToJoint,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.scripts.lerobot_record import record_loop
+from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
+from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
+from lerobot.teleoperators.phone.teleop_phone import Phone
+from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.utils.utils import log_say
+from lerobot.utils.visualization_utils import init_rerun
+
+NUM_EPISODES = 2
+FPS = 30
+EPISODE_TIME_SEC = 60
+RESET_TIME_SEC = 30
+TASK_DESCRIPTION = "My task description"
+HF_REPO_ID = "/"
+
+# Create the robot and teleoperator configurations
+camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411",
+ id="my_awesome_follower_arm",
+ cameras=camera_config,
+ use_degrees=True,
+)
+teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
+
+# Initialize the robot and teleoperator
+robot = SO100Follower(robot_config)
+phone = Phone(teleop_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert phone action to EE action
+phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
+ EEReferenceAndDelta(
+ kinematics=kinematics_solver,
+ end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
+ motor_names=list(robot.bus.motors.keys()),
+ use_latched_reference=True,
+ ),
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
+ max_ee_step_m=0.20,
+ ),
+ GripperVelocityToJoint(speed_factor=20.0),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Build pipeline to convert EE action to joints action
+robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Build pipeline to convert joint observation to EE observation
+robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
+ steps=[
+ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
+ ],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+)
+
+# Create the dataset
+dataset = LeRobotDataset.create(
+ repo_id=HF_REPO_ID,
+ fps=FPS,
+ features=combine_feature_dicts(
+ # Run the feature contract of the pipelines
+ # This tells you how the features would look like after the pipeline steps
+ aggregate_pipeline_dataset_features(
+ pipeline=phone_to_robot_ee_pose_processor,
+ initial_features=create_initial_features(action=phone.action_features),
+ use_videos=True,
+ ),
+ aggregate_pipeline_dataset_features(
+ pipeline=robot_joints_to_ee_pose,
+ initial_features=create_initial_features(observation=robot.observation_features),
+ use_videos=True,
+ ),
+ ),
+ robot_type=robot.name,
+ use_videos=True,
+ image_writer_threads=4,
+)
+
+# Connect the robot and teleoperator
+robot.connect()
+phone.connect()
+
+# Initialize the keyboard listener and rerun visualization
+listener, events = init_keyboard_listener()
+init_rerun(session_name="phone_so100_record")
+
+if not robot.is_connected or not phone.is_connected:
+ raise ValueError("Robot or teleop is not connected!")
+
+
+print("Starting record loop. Move your phone to teleoperate the robot...")
+episode_idx = 0
+while episode_idx < NUM_EPISODES and not events["stop_recording"]:
+ log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
+
+ # Main record loop
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ teleop=phone,
+ dataset=dataset,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=phone_to_robot_ee_pose_processor,
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose,
+ )
+
+ # Reset the environment if not stopping or re-recording
+ if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
+ log_say("Reset the environment")
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ teleop=phone,
+ control_time_s=RESET_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=phone_to_robot_ee_pose_processor,
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose,
+ )
+
+ if events["rerecord_episode"]:
+ log_say("Re-recording episode")
+ events["rerecord_episode"] = False
+ events["exit_early"] = False
+ dataset.clear_episode_buffer()
+ continue
+
+ # Save episode
+ dataset.save_episode()
+ episode_idx += 1
+
+# Clean up
+log_say("Stop recording")
+robot.disconnect()
+phone.disconnect()
+listener.stop()
+dataset.push_to_hub()
diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py
new file mode 100644
index 000000000..f1181143c
--- /dev/null
+++ b/examples/phone_to_so100/replay.py
@@ -0,0 +1,100 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ robot_action_observation_to_transition,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.utils.constants import ACTION
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.utils import log_say
+
+EPISODE_IDX = 0
+HF_REPO_ID = "/"
+
+# Initialize the robot config
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
+)
+
+# Initialize the robot
+robot = SO100Follower(robot_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert EE action to joints action
+robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=False, # Because replay is open loop
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Fetch the dataset to replay
+dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
+# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
+episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
+actions = episode_frames.select_columns(ACTION)
+
+# Connect to the robot
+robot.connect()
+
+if not robot.is_connected:
+ raise ValueError("Robot is not connected!")
+
+print("Starting replay loop...")
+log_say(f"Replaying episode {EPISODE_IDX}")
+for idx in range(len(episode_frames)):
+ t0 = time.perf_counter()
+
+ # Get recorded action from dataset
+ ee_action = {
+ name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
+ }
+
+ # Get robot observation
+ robot_obs = robot.get_observation()
+
+ # Dataset EE -> robot joints
+ joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
+
+ # Send action to robot
+ _ = robot.send_action(joint_action)
+
+ busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
+
+# Clean up
+robot.disconnect()
diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py
new file mode 100644
index 000000000..783dce242
--- /dev/null
+++ b/examples/phone_to_so100/teleoperate.py
@@ -0,0 +1,113 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specif
+
+import time
+
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ robot_action_observation_to_transition,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ EEBoundsAndSafety,
+ EEReferenceAndDelta,
+ GripperVelocityToJoint,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
+from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
+from lerobot.teleoperators.phone.teleop_phone import Phone
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+
+FPS = 30
+
+# Initialize the robot and teleoperator
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
+)
+teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
+
+# Initialize the robot and teleoperator
+robot = SO100Follower(robot_config)
+teleop_device = Phone(teleop_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert phone action to ee pose action to joint action
+phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
+ EEReferenceAndDelta(
+ kinematics=kinematics_solver,
+ end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
+ motor_names=list(robot.bus.motors.keys()),
+ use_latched_reference=True,
+ ),
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
+ max_ee_step_m=0.10,
+ ),
+ GripperVelocityToJoint(
+ speed_factor=20.0,
+ ),
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Connect to the robot and teleoperator
+robot.connect()
+teleop_device.connect()
+
+# Init rerun viewer
+init_rerun(session_name="phone_so100_teleop")
+
+if not robot.is_connected or not teleop_device.is_connected:
+ raise ValueError("Robot or teleop is not connected!")
+
+print("Starting teleop loop. Move your phone to teleoperate the robot...")
+while True:
+ t0 = time.perf_counter()
+
+ # Get robot observation
+ robot_obs = robot.get_observation()
+
+ # Get teleop action
+ phone_obs = teleop_device.get_action()
+
+ # Phone -> EE pose -> Joints transition
+ joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
+
+ # Send action to robot
+ _ = robot.send_action(joint_action)
+
+ # Visualize
+ log_rerun_data(observation=phone_obs, action=joint_action)
+
+ busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
diff --git a/examples/port_datasets/display_error_files.py b/examples/port_datasets/display_error_files.py
new file mode 100644
index 000000000..fffab5ff3
--- /dev/null
+++ b/examples/port_datasets/display_error_files.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+from pathlib import Path
+
+
+def find_missing_workers(completions_dir, world_size):
+ """Find workers that are not completed and returns their indices."""
+ full = list(range(world_size))
+
+ completed = []
+ for path in completions_dir.glob("*"):
+ if path.name in [".", ".."]:
+ continue
+ index = path.name.lstrip("0")
+ index = 0 if index == "" else int(index)
+ completed.append(index)
+
+ missing_workers = set(full) - set(completed)
+ return missing_workers
+
+
+def find_output_files(slurm_dir, worker_indices):
+ """Find output files associated to worker indices, and return tuples
+ of (worker index, output file path)
+ """
+ out_files = []
+ for path in slurm_dir.glob("*.out"):
+ _, worker_id = path.name.replace(".out", "").split("_")
+ worker_id = int(worker_id)
+ if worker_id in worker_indices:
+ out_files.append((worker_id, path))
+ return out_files
+
+
+def display_error_files(logs_dir, job_name):
+ executor_path = Path(logs_dir) / job_name / "executor.json"
+ completions_dir = Path(logs_dir) / job_name / "completions"
+
+ with open(executor_path) as f:
+ executor = json.load(f)
+
+ missing_workers = find_missing_workers(completions_dir, executor["world_size"])
+
+ for missing in sorted(missing_workers)[::-1]:
+ print(missing)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--logs-dir",
+ type=str,
+ help="Path to logs directory for `datatrove`.",
+ )
+ parser.add_argument(
+ "--job-name",
+ type=str,
+ default="port_droid",
+ help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
+ )
+
+ args = parser.parse_args()
+
+ display_error_files(**vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py
new file mode 100644
index 000000000..4efb131e4
--- /dev/null
+++ b/examples/port_datasets/port_droid.py
@@ -0,0 +1,430 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import time
+from pathlib import Path
+
+import numpy as np
+import tensorflow_datasets as tfds
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
+
+DROID_SHARDS = 2048
+DROID_FPS = 15
+DROID_ROBOT_TYPE = "Franka"
+
+# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
+DROID_FEATURES = {
+ # true on first step of the episode
+ "is_first": {
+ "dtype": "bool",
+ "shape": (1,),
+ "names": None,
+ },
+ # true on last step of the episode
+ "is_last": {
+ "dtype": "bool",
+ "shape": (1,),
+ "names": None,
+ },
+ # true on last step of the episode if it is a terminal step, True for demos
+ "is_terminal": {
+ "dtype": "bool",
+ "shape": (1,),
+ "names": None,
+ },
+ # language_instruction is also stored as "task" to follow LeRobot standard
+ "language_instruction": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "language_instruction_2": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "language_instruction_3": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "observation.state.gripper_position": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": {
+ "axes": ["gripper"],
+ },
+ },
+ "observation.state.cartesian_position": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "observation.state.joint_position": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": {
+ "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
+ },
+ },
+ # Add this new feature to follow LeRobot standard of using joint position + gripper
+ "observation.state": {
+ "dtype": "float32",
+ "shape": (8,),
+ "names": {
+ "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
+ },
+ },
+ # Initially called wrist_image_left
+ "observation.images.wrist_left": {
+ "dtype": "video",
+ "shape": (180, 320, 3),
+ "names": [
+ "height",
+ "width",
+ "channels",
+ ],
+ },
+ # Initially called exterior_image_1_left
+ "observation.images.exterior_1_left": {
+ "dtype": "video",
+ "shape": (180, 320, 3),
+ "names": [
+ "height",
+ "width",
+ "channels",
+ ],
+ },
+ # Initially called exterior_image_2_left
+ "observation.images.exterior_2_left": {
+ "dtype": "video",
+ "shape": (180, 320, 3),
+ "names": [
+ "height",
+ "width",
+ "channels",
+ ],
+ },
+ "action.gripper_position": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": {
+ "axes": ["gripper"],
+ },
+ },
+ "action.gripper_velocity": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": {
+ "axes": ["gripper"],
+ },
+ },
+ "action.cartesian_position": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "action.cartesian_velocity": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "action.joint_position": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": {
+ "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
+ },
+ },
+ "action.joint_velocity": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": {
+ "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
+ },
+ },
+ # This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
+ "action.original": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
+ },
+ },
+ # Add this new feature to follow LeRobot standard of using joint position + gripper
+ "action": {
+ "dtype": "float32",
+ "shape": (8,),
+ "names": {
+ "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
+ },
+ },
+ "discount": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": None,
+ },
+ "reward": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": None,
+ },
+ # Meta data that are the same for all frames in the episode
+ "task_category": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "building": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "collector_id": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "date": {
+ "dtype": "string",
+ "shape": (1,),
+ "names": None,
+ },
+ "camera_extrinsics.wrist_left": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "camera_extrinsics.exterior_1_left": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "camera_extrinsics.exterior_2_left": {
+ "dtype": "float32",
+ "shape": (6,),
+ "names": {
+ "axes": ["x", "y", "z", "roll", "pitch", "yaw"],
+ },
+ },
+ "is_episode_successful": {
+ "dtype": "bool",
+ "shape": (1,),
+ "names": None,
+ },
+}
+
+
+def is_episode_successful(tf_episode_metadata):
+ # Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
+ return "/success/" in tf_episode_metadata["file_path"].numpy().decode()
+
+
+def generate_lerobot_frames(tf_episode):
+ m = tf_episode["episode_metadata"]
+ frame_meta = {
+ "task_category": m["building"].numpy().decode(),
+ "building": m["building"].numpy().decode(),
+ "collector_id": m["collector_id"].numpy().decode(),
+ "date": m["date"].numpy().decode(),
+ "camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
+ "camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
+ "camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
+ "is_episode_successful": np.array([is_episode_successful(m)]),
+ }
+ for f in tf_episode["steps"]:
+ # Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
+ frame = {
+ "is_first": np.array([f["is_first"].numpy()]),
+ "is_last": np.array([f["is_last"].numpy()]),
+ "is_terminal": np.array([f["is_terminal"].numpy()]),
+ "language_instruction": f["language_instruction"].numpy().decode(),
+ "language_instruction_2": f["language_instruction_2"].numpy().decode(),
+ "language_instruction_3": f["language_instruction_3"].numpy().decode(),
+ "observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
+ "observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
+ "observation.state.joint_position": f["observation"]["joint_position"].numpy(),
+ "observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
+ "observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
+ "observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
+ "action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
+ "action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
+ "action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
+ "action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
+ "action.joint_position": f["action_dict"]["joint_position"].numpy(),
+ "action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
+ "discount": np.array([f["discount"].numpy()]),
+ "reward": np.array([f["reward"].numpy()]),
+ "action.original": f["action"].numpy(),
+ }
+
+ # language_instruction is also stored as "task" to follow LeRobot standard
+ frame["task"] = frame["language_instruction"]
+
+ # Add this new feature to follow LeRobot standard of using joint position + gripper
+ frame["observation.state"] = np.concatenate(
+ [frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
+ )
+ frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])
+
+ # Meta data that are the same for all frames in the episode
+ frame.update(frame_meta)
+
+ # Cast fp64 to fp32
+ for key in frame:
+ if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
+ frame[key] = frame[key].astype(np.float32)
+
+ yield frame
+
+
+def port_droid(
+ raw_dir: Path,
+ repo_id: str,
+ push_to_hub: bool = False,
+ num_shards: int | None = None,
+ shard_index: int | None = None,
+):
+ dataset_name = raw_dir.parent.name
+ version = raw_dir.name
+ data_dir = raw_dir.parent.parent
+
+ builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
+
+ if num_shards is not None:
+ tfds_num_shards = builder.info.splits["train"].num_shards
+ if tfds_num_shards != DROID_SHARDS:
+ raise ValueError(
+ f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
+ )
+ if num_shards != tfds_num_shards:
+ raise ValueError(
+ f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
+ )
+ if shard_index >= tfds_num_shards:
+ raise ValueError(
+ f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
+ )
+
+ raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
+ else:
+ raw_dataset = builder.as_dataset(split="train")
+
+ lerobot_dataset = LeRobotDataset.create(
+ repo_id=repo_id,
+ robot_type=DROID_ROBOT_TYPE,
+ fps=DROID_FPS,
+ features=DROID_FEATURES,
+ )
+
+ start_time = time.time()
+ num_episodes = raw_dataset.cardinality().numpy().item()
+ logging.info(f"Number of episodes {num_episodes}")
+
+ for episode_index, episode in enumerate(raw_dataset):
+ elapsed_time = time.time() - start_time
+ d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
+
+ logging.info(
+ f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
+ )
+
+ for frame in generate_lerobot_frames(episode):
+ lerobot_dataset.add_frame(frame)
+
+ lerobot_dataset.save_episode()
+ logging.info("Save_episode")
+
+ if push_to_hub:
+ lerobot_dataset.push_to_hub(
+ # Add openx tag, since it belongs to the openx collection of datasets
+ tags=["openx"],
+ private=False,
+ )
+
+
+def validate_dataset(repo_id):
+ """Sanity check that ensure meta data can be loaded and all files are present."""
+ meta = LeRobotDatasetMetadata(repo_id)
+
+ if meta.total_episodes == 0:
+ raise ValueError("Number of episodes is 0.")
+
+ for ep_idx in range(meta.total_episodes):
+ data_path = meta.root / meta.get_data_file_path(ep_idx)
+
+ if not data_path.exists():
+ raise ValueError(f"Parquet file is missing in: {data_path}")
+
+ for vid_key in meta.video_keys:
+ vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
+ if not vid_path.exists():
+ raise ValueError(f"Video file is missing in: {vid_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--raw-dir",
+ type=Path,
+ required=True,
+ help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
+ )
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
+ )
+ parser.add_argument(
+ "--push-to-hub",
+ action="store_true",
+ help="Upload to hub.",
+ )
+ parser.add_argument(
+ "--num-shards",
+ type=int,
+ default=None,
+ help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
+ )
+ parser.add_argument(
+ "--shard-index",
+ type=int,
+ default=None,
+ help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
+ )
+
+ args = parser.parse_args()
+
+ port_droid(**vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/port_datasets/slurm_aggregate_shards.py b/examples/port_datasets/slurm_aggregate_shards.py
new file mode 100644
index 000000000..4e1b71a31
--- /dev/null
+++ b/examples/port_datasets/slurm_aggregate_shards.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+
+from datatrove.executor import LocalPipelineExecutor
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.base import PipelineStep
+from port_datasets.droid_rlds.port_droid import DROID_SHARDS
+
+from lerobot.datasets.aggregate import aggregate_datasets
+from lerobot.utils.utils import init_logging
+
+
+class AggregateDatasets(PipelineStep):
+ def __init__(
+ self,
+ repo_ids: list[str],
+ aggregated_repo_id: str,
+ ):
+ super().__init__()
+ self.repo_ids = repo_ids
+ self.aggr_repo_id = aggregated_repo_id
+
+ def run(self, data=None, rank: int = 0, world_size: int = 1):
+ init_logging()
+
+ # Since aggregate_datasets already handles parallel processing internally,
+ # we only need one worker to run the entire aggregation
+ if rank == 0:
+ logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
+ aggregate_datasets(self.repo_ids, self.aggr_repo_id)
+ logging.info("Aggregation complete!")
+ else:
+ logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
+
+
+def make_aggregate_executor(
+ repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
+):
+ kwargs = {
+ "pipeline": [
+ AggregateDatasets(repo_ids, repo_id),
+ ],
+ "logging_dir": str(logs_dir / job_name),
+ }
+
+ if slurm:
+ # For aggregation, we only need 1 task since aggregate_datasets handles everything
+ kwargs.update(
+ {
+ "job_name": job_name,
+ "tasks": 1, # Only need 1 task for aggregation
+ "workers": 1, # Only need 1 worker
+ "time": "08:00:00",
+ "partition": partition,
+ "cpus_per_task": cpus_per_task,
+ "sbatch_args": {"mem-per-cpu": mem_per_cpu},
+ }
+ )
+ executor = SlurmPipelineExecutor(**kwargs)
+ else:
+ kwargs.update(
+ {
+ "tasks": 1,
+ "workers": 1,
+ }
+ )
+ executor = LocalPipelineExecutor(**kwargs)
+
+ return executor
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
+ )
+ parser.add_argument(
+ "--logs-dir",
+ type=Path,
+ help="Path to logs directory for `datatrove`.",
+ )
+ parser.add_argument(
+ "--job-name",
+ type=str,
+ default="aggr_droid",
+ help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
+ )
+ parser.add_argument(
+ "--slurm",
+ type=int,
+ default=1,
+ help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=1, # Changed default to 1 since aggregation doesn't need multiple workers
+ help="Number of slurm workers. For aggregation, this should be 1.",
+ )
+ parser.add_argument(
+ "--partition",
+ type=str,
+ help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
+ )
+ parser.add_argument(
+ "--cpus-per-task",
+ type=int,
+ default=8,
+ help="Number of cpus that each slurm worker will use.",
+ )
+ parser.add_argument(
+ "--mem-per-cpu",
+ type=str,
+ default="1950M",
+ help="Memory per cpu that each worker will use.",
+ )
+
+ args = parser.parse_args()
+ kwargs = vars(args)
+ kwargs["slurm"] = kwargs.pop("slurm") == 1
+
+ repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
+ aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
+ aggregate_executor.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/port_datasets/slurm_port_shards.py b/examples/port_datasets/slurm_port_shards.py
new file mode 100644
index 000000000..3bb4c135c
--- /dev/null
+++ b/examples/port_datasets/slurm_port_shards.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+
+from datatrove.executor import LocalPipelineExecutor
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.base import PipelineStep
+from port_datasets.droid_rlds.port_droid import DROID_SHARDS
+
+
+class PortDroidShards(PipelineStep):
+ def __init__(
+ self,
+ raw_dir: Path | str,
+ repo_id: str = None,
+ ):
+ super().__init__()
+ self.raw_dir = Path(raw_dir)
+ self.repo_id = repo_id
+
+ def run(self, data=None, rank: int = 0, world_size: int = 1):
+ from datasets.utils.tqdm import disable_progress_bars
+ from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
+
+ from lerobot.utils.utils import init_logging
+
+ init_logging()
+ disable_progress_bars()
+
+ shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
+
+ try:
+ validate_dataset(shard_repo_id)
+ return
+ except Exception:
+ pass # nosec B110 - Dataset doesn't exist yet, continue with porting
+
+ port_droid(
+ self.raw_dir,
+ shard_repo_id,
+ push_to_hub=False,
+ num_shards=world_size,
+ shard_index=rank,
+ )
+
+ validate_dataset(shard_repo_id)
+
+
+def make_port_executor(
+ raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
+):
+ kwargs = {
+ "pipeline": [
+ PortDroidShards(raw_dir, repo_id),
+ ],
+ "logging_dir": str(logs_dir / job_name),
+ }
+
+ if slurm:
+ kwargs.update(
+ {
+ "job_name": job_name,
+ "tasks": DROID_SHARDS,
+ "workers": workers,
+ "time": "08:00:00",
+ "partition": partition,
+ "cpus_per_task": cpus_per_task,
+ "sbatch_args": {"mem-per-cpu": mem_per_cpu},
+ }
+ )
+ executor = SlurmPipelineExecutor(**kwargs)
+ else:
+ kwargs.update(
+ {
+ "tasks": 1,
+ "workers": 1,
+ }
+ )
+ executor = LocalPipelineExecutor(**kwargs)
+
+ return executor
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--raw-dir",
+ type=Path,
+ required=True,
+ help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
+ )
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
+ )
+ parser.add_argument(
+ "--logs-dir",
+ type=Path,
+ help="Path to logs directory for `datatrove`.",
+ )
+ parser.add_argument(
+ "--job-name",
+ type=str,
+ default="port_droid",
+ help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
+ )
+ parser.add_argument(
+ "--slurm",
+ type=int,
+ default=1,
+ help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=2048,
+ help="Number of slurm workers. It should be less than the maximum number of shards.",
+ )
+ parser.add_argument(
+ "--partition",
+ type=str,
+ help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
+ )
+ parser.add_argument(
+ "--cpus-per-task",
+ type=int,
+ default=8,
+ help="Number of cpus that each slurm worker will use.",
+ )
+ parser.add_argument(
+ "--mem-per-cpu",
+ type=str,
+ default="1950M",
+ help="Memory per cpu that each worker will use.",
+ )
+
+ args = parser.parse_args()
+ kwargs = vars(args)
+ kwargs["slurm"] = kwargs.pop("slurm") == 1
+ port_executor = make_port_executor(**kwargs)
+ port_executor.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py
new file mode 100644
index 000000000..ade1ef874
--- /dev/null
+++ b/examples/port_datasets/slurm_upload.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from datatrove.executor import LocalPipelineExecutor
+from datatrove.executor.slurm import SlurmPipelineExecutor
+from datatrove.pipeline.base import PipelineStep
+from huggingface_hub import HfApi
+from huggingface_hub.constants import REPOCARD_NAME
+from port_datasets.droid_rlds.port_droid import DROID_SHARDS
+
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
+from lerobot.datasets.utils import create_lerobot_dataset_card
+from lerobot.utils.utils import init_logging
+
+
+class UploadDataset(PipelineStep):
+ def __init__(
+ self,
+ repo_id: str,
+ branch: str | None = None,
+ revision: str | None = None,
+ tags: list | None = None,
+ license: str | None = "apache-2.0",
+ private: bool = False,
+ distant_repo_id: str | None = None,
+ **card_kwargs,
+ ):
+ super().__init__()
+ self.repo_id = repo_id
+ self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
+ self.branch = branch
+ self.tags = tags
+ self.license = license
+ self.private = private
+ self.card_kwargs = card_kwargs
+ self.revision = revision if revision else CODEBASE_VERSION
+
+ if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
+ logging.warning(
+ 'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
+ "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
+ )
+
+ self.create_repo()
+
+ def create_repo(self):
+ logging.info(f"Loading meta data from {self.repo_id}...")
+ meta = LeRobotDatasetMetadata(self.repo_id)
+
+ logging.info(f"Creating repo {self.distant_repo_id}...")
+ hub_api = HfApi()
+ hub_api.create_repo(
+ repo_id=self.distant_repo_id,
+ private=self.private,
+ repo_type="dataset",
+ exist_ok=True,
+ )
+ if self.branch:
+ hub_api.create_branch(
+ repo_id=self.distant_repo_id,
+ branch=self.branch,
+ revision=self.revision,
+ repo_type="dataset",
+ exist_ok=True,
+ )
+
+ if not hub_api.file_exists(
+ self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
+ ):
+ card = create_lerobot_dataset_card(
+ tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
+ )
+ card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
+
+ hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+
+ def list_files_recursively(directory):
+ base_path = Path(directory)
+ return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
+
+ logging.info(f"Listing all local files from {self.repo_id}...")
+ self.file_paths = list_files_recursively(meta.root)
+ self.file_paths = sorted(self.file_paths)
+
+ def create_chunks(self, lst, n):
+ from itertools import islice
+
+ it = iter(lst)
+ return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]
+
+ def create_commits(self, additions):
+ import logging
+ import math
+ import random
+ import time
+
+ from huggingface_hub import create_commit
+ from huggingface_hub.utils import HfHubHTTPError
+
+ FILES_BETWEEN_COMMITS = 10 # noqa: N806
+ BASE_DELAY = 0.1 # noqa: N806
+ MAX_RETRIES = 12 # noqa: N806
+
+ # Split the files into smaller chunks for faster commit
+ # and avoiding "A commit has happened since" error
+ num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
+ chunks = self.create_chunks(additions, num_chunks)
+
+ for chunk in chunks:
+ retries = 0
+ while True:
+ try:
+ create_commit(
+ self.distant_repo_id,
+ repo_type="dataset",
+ operations=chunk,
+ commit_message=f"DataTrove upload ({len(chunk)} files)",
+ revision=self.branch,
+ )
+ # TODO: every 100 chunks super_squach_commits()
+ logging.info("create_commit completed!")
+ break
+ except HfHubHTTPError as e:
+ if "A commit has happened since" in e.server_message:
+ if retries >= MAX_RETRIES:
+ logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
+ raise e
+ logging.info("Commit creation race condition issue. Waiting...")
+ time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
+ retries += 1
+ else:
+ raise e
+
+ def run(self, data=None, rank: int = 0, world_size: int = 1):
+ import logging
+
+ from datasets.utils.tqdm import disable_progress_bars
+ from huggingface_hub import CommitOperationAdd, preupload_lfs_files
+
+ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+ from lerobot.utils.utils import init_logging
+
+ init_logging()
+ disable_progress_bars()
+
+ chunks = self.create_chunks(self.file_paths, world_size)
+ file_paths = chunks[rank]
+
+ if len(file_paths) == 0:
+ raise ValueError(file_paths)
+
+ logging.info("Pre-uploading LFS files...")
+ for i, path in enumerate(file_paths):
+ logging.info(f"{i}: {path}")
+
+ meta = LeRobotDatasetMetadata(self.repo_id)
+ additions = [
+ CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
+ ]
+ preupload_lfs_files(
+ repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
+ )
+
+ logging.info("Creating commits...")
+ self.create_commits(additions)
+ logging.info("Done!")
+
+
+def make_upload_executor(
+ repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
+):
+ kwargs = {
+ "pipeline": [
+ UploadDataset(repo_id),
+ ],
+ "logging_dir": str(logs_dir / job_name),
+ }
+
+ if slurm:
+ kwargs.update(
+ {
+ "job_name": job_name,
+ "tasks": DROID_SHARDS,
+ "workers": workers,
+ "time": "08:00:00",
+ "partition": partition,
+ "cpus_per_task": cpus_per_task,
+ "sbatch_args": {"mem-per-cpu": mem_per_cpu},
+ }
+ )
+ executor = SlurmPipelineExecutor(**kwargs)
+ else:
+ kwargs.update(
+ {
+ "tasks": DROID_SHARDS,
+ "workers": 1,
+ }
+ )
+ executor = LocalPipelineExecutor(**kwargs)
+
+ return executor
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
+ )
+ parser.add_argument(
+ "--logs-dir",
+ type=Path,
+ help="Path to logs directory for `datatrove`.",
+ )
+ parser.add_argument(
+ "--job-name",
+ type=str,
+ default="upload_droid",
+ help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
+ )
+ parser.add_argument(
+ "--slurm",
+ type=int,
+ default=1,
+ help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
+ )
+ parser.add_argument(
+ "--workers",
+ type=int,
+ default=50,
+ help="Number of slurm workers. It should be less than the maximum number of shards.",
+ )
+ parser.add_argument(
+ "--partition",
+ type=str,
+ help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
+ )
+ parser.add_argument(
+ "--cpus-per-task",
+ type=int,
+ default=8,
+ help="Number of cpus that each slurm worker will use.",
+ )
+ parser.add_argument(
+ "--mem-per-cpu",
+ type=str,
+ default="1950M",
+ help="Memory per cpu that each worker will use.",
+ )
+
+ init_logging()
+
+ args = parser.parse_args()
+ kwargs = vars(args)
+ kwargs["slurm"] = kwargs.pop("slurm") == 1
+ upload_executor = make_upload_executor(**kwargs)
+ upload_executor.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py
new file mode 100644
index 000000000..53a385442
--- /dev/null
+++ b/examples/so100_to_so100_EE/evaluate.py
@@ -0,0 +1,198 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.configs.types import FeatureType, PolicyFeature
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from lerobot.datasets.utils import combine_feature_dicts
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.processor import (
+ RobotAction,
+ RobotObservation,
+ RobotProcessorPipeline,
+ make_default_teleop_action_processor,
+)
+from lerobot.processor.converters import (
+ observation_to_transition,
+ robot_action_observation_to_transition,
+ transition_to_observation,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ ForwardKinematicsJointsToEE,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.scripts.lerobot_record import record_loop
+from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.utils.utils import log_say
+from lerobot.utils.visualization_utils import init_rerun
+
+NUM_EPISODES = 5
+FPS = 30
+EPISODE_TIME_SEC = 60
+TASK_DESCRIPTION = "My task description"
+HF_MODEL_ID = "/"
+HF_DATASET_ID = "/"
+
+# Create the robot configuration & robot
+camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411",
+ id="my_awesome_follower_arm",
+ cameras=camera_config,
+ use_degrees=True,
+)
+
+robot = SO100Follower(robot_config)
+
+# Create policy
+policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert EE action to joints action
+robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Build pipeline to convert joints observation to EE observation
+robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
+ steps=[
+ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
+ ],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+)
+
+
+# Create the dataset
+dataset = LeRobotDataset.create(
+ repo_id=HF_DATASET_ID,
+ fps=FPS,
+ features=combine_feature_dicts(
+ aggregate_pipeline_dataset_features(
+ pipeline=robot_joints_to_ee_pose_processor,
+ initial_features=create_initial_features(observation=robot.observation_features),
+ use_videos=True,
+ ),
+ # User for now should be explicit on the feature keys that were used for record
+ # Alternatively, the user can pass the processor step that has the right features
+ aggregate_pipeline_dataset_features(
+ pipeline=make_default_teleop_action_processor(),
+ initial_features=create_initial_features(
+ action={
+ f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
+ for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
+ }
+ ),
+ use_videos=True,
+ ),
+ ),
+ robot_type=robot.name,
+ use_videos=True,
+ image_writer_threads=4,
+)
+
+# Build Policy Processors
+preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=policy,
+ pretrained_path=HF_MODEL_ID,
+ dataset_stats=dataset.meta.stats,
+ # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
+ preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+)
+
+# Connect the robot and teleoperator
+robot.connect()
+
+# Initialize the keyboard listener and rerun visualization
+listener, events = init_keyboard_listener()
+init_rerun(session_name="so100_so100_evaluate")
+
+if not robot.is_connected:
+ raise ValueError("Robot is not connected!")
+
+print("Starting evaluate loop...")
+episode_idx = 0
+for episode_idx in range(NUM_EPISODES):
+ log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
+
+ # Main record loop
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ policy=policy,
+ preprocessor=preprocessor, # Pass the pre and post policy processors
+ postprocessor=postprocessor,
+ dataset=dataset,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=make_default_teleop_action_processor(),
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose_processor,
+ )
+
+ # Reset the environment if not stopping or re-recording
+ if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
+ log_say("Reset the environment")
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=FPS,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=make_default_teleop_action_processor(),
+ robot_action_processor=robot_ee_to_joints_processor,
+ robot_observation_processor=robot_joints_to_ee_pose_processor,
+ )
+
+ if events["rerecord_episode"]:
+ log_say("Re-record episode")
+ events["rerecord_episode"] = False
+ events["exit_early"] = False
+ dataset.clear_episode_buffer()
+ continue
+
+ # Save episode
+ dataset.save_episode()
+ episode_idx += 1
+
+# Clean up
+log_say("Stop recording")
+robot.disconnect()
+listener.stop()
+dataset.push_to_hub()
diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py
new file mode 100644
index 000000000..9ed6e51a9
--- /dev/null
+++ b/examples/so100_to_so100_EE/record.py
@@ -0,0 +1,202 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from lerobot.datasets.utils import combine_feature_dicts
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ observation_to_transition,
+ robot_action_observation_to_transition,
+ transition_to_observation,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ EEBoundsAndSafety,
+ ForwardKinematicsJointsToEE,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.scripts.lerobot_record import record_loop
+from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
+from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
+from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.utils.utils import log_say
+from lerobot.utils.visualization_utils import init_rerun
+
+NUM_EPISODES = 2
+FPS = 30
+EPISODE_TIME_SEC = 60
+RESET_TIME_SEC = 30
+TASK_DESCRIPTION = "My task description"
+HF_REPO_ID = "/"
+
+# Create the robot and teleoperator configurations
+camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
+follower_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True
+)
+leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
+
+# Initialize the robot and teleoperator
+follower = SO100Follower(follower_config)
+leader = SO100Leader(leader_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+follower_kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(follower.bus.motors.keys()),
+)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+leader_kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(leader.bus.motors.keys()),
+)
+
+# Build pipeline to convert follower joints to EE observation
+follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
+ steps=[
+ ForwardKinematicsJointsToEE(
+ kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
+ ),
+ ],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+)
+
+# Build pipeline to convert leader joints to EE action
+leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ ForwardKinematicsJointsToEE(
+ kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Build pipeline to convert EE action to follower joints
+ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ [
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
+ max_ee_step_m=0.10,
+ ),
+ InverseKinematicsEEToJoints(
+ kinematics=follower_kinematics_solver,
+ motor_names=list(follower.bus.motors.keys()),
+ initial_guess_current_joints=True,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Create the dataset
+dataset = LeRobotDataset.create(
+ repo_id=HF_REPO_ID,
+ fps=FPS,
+ features=combine_feature_dicts(
+ # Run the feature contract of the pipelines
+ # This tells you how the features would look like after the pipeline steps
+ aggregate_pipeline_dataset_features(
+ pipeline=leader_joints_to_ee,
+ initial_features=create_initial_features(action=leader.action_features),
+ use_videos=True,
+ ),
+ aggregate_pipeline_dataset_features(
+ pipeline=follower_joints_to_ee,
+ initial_features=create_initial_features(observation=follower.observation_features),
+ use_videos=True,
+ ),
+ ),
+ robot_type=follower.name,
+ use_videos=True,
+ image_writer_threads=4,
+)
+
+
+# Connect the robot and teleoperator
+leader.connect()
+follower.connect()
+
+# Initialize the keyboard listener and rerun visualization
+listener, events = init_keyboard_listener()
+init_rerun(session_name="recording_phone")
+
+if not leader.is_connected or not follower.is_connected:
+ raise ValueError("Robot or teleop is not connected!")
+
+print("Starting record loop...")
+episode_idx = 0
+while episode_idx < NUM_EPISODES and not events["stop_recording"]:
+ log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
+
+ # Main record loop
+ record_loop(
+ robot=follower,
+ events=events,
+ fps=FPS,
+ teleop=leader,
+ dataset=dataset,
+ control_time_s=EPISODE_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=leader_joints_to_ee,
+ robot_action_processor=ee_to_follower_joints,
+ robot_observation_processor=follower_joints_to_ee,
+ )
+
+ # Reset the environment if not stopping or re-recording
+ if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
+ log_say("Reset the environment")
+ record_loop(
+ robot=follower,
+ events=events,
+ fps=FPS,
+ teleop=leader,
+ control_time_s=RESET_TIME_SEC,
+ single_task=TASK_DESCRIPTION,
+ display_data=True,
+ teleop_action_processor=leader_joints_to_ee,
+ robot_action_processor=ee_to_follower_joints,
+ robot_observation_processor=follower_joints_to_ee,
+ )
+
+ if events["rerecord_episode"]:
+ log_say("Re-recording episode")
+ events["rerecord_episode"] = False
+ events["exit_early"] = False
+ dataset.clear_episode_buffer()
+ continue
+
+ # Save episode
+ dataset.save_episode()
+ episode_idx += 1
+
+# Clean up
+log_say("Stop recording")
+leader.disconnect()
+follower.disconnect()
+listener.stop()
+dataset.push_to_hub()
diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py
new file mode 100644
index 000000000..ea78d4e66
--- /dev/null
+++ b/examples/so100_to_so100_EE/replay.py
@@ -0,0 +1,101 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ robot_action_observation_to_transition,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.utils.constants import ACTION
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.utils import log_say
+
+EPISODE_IDX = 0
+HF_REPO_ID = "/"
+
+# Initialize the robot config
+robot_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
+)
+
+# Initialize the robot
+robot = SO100Follower(robot_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(robot.bus.motors.keys()),
+)
+
+# Build pipeline to convert EE action to joints action
+robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[
+ InverseKinematicsEEToJoints(
+ kinematics=kinematics_solver,
+ motor_names=list(robot.bus.motors.keys()),
+ initial_guess_current_joints=False, # Because replay is open loop
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Fetch the dataset to replay
+dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
+# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
+episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
+actions = episode_frames.select_columns(ACTION)
+
+# Connect to the robot
+robot.connect()
+
+if not robot.is_connected:
+ raise ValueError("Robot is not connected!")
+
+print("Starting replay loop...")
+log_say(f"Replaying episode {EPISODE_IDX}")
+for idx in range(len(episode_frames)):
+ t0 = time.perf_counter()
+
+ # Get recorded action from dataset
+ ee_action = {
+ name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
+ }
+
+ # Get robot observation
+ robot_obs = robot.get_observation()
+
+ # Dataset EE -> robot joints
+ joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
+
+ # Send action to robot
+ _ = robot.send_action(joint_action)
+
+ busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
+
+# Clean up
+robot.disconnect()
diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py
new file mode 100644
index 000000000..b1a8c8c27
--- /dev/null
+++ b/examples/so100_to_so100_EE/teleoperate.py
@@ -0,0 +1,121 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
+from lerobot.processor.converters import (
+ robot_action_observation_to_transition,
+ robot_action_to_transition,
+ transition_to_robot_action,
+)
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ EEBoundsAndSafety,
+ ForwardKinematicsJointsToEE,
+ InverseKinematicsEEToJoints,
+)
+from lerobot.robots.so100_follower.so100_follower import SO100Follower
+from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
+from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+
+FPS = 30
+
+# Initialize the robot and teleoperator config
+follower_config = SO100FollowerConfig(
+ port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
+)
+leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
+
+# Initialize the robot and teleoperator
+follower = SO100Follower(follower_config)
+leader = SO100Leader(leader_config)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+follower_kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(follower.bus.motors.keys()),
+)
+
+# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
+leader_kinematics_solver = RobotKinematics(
+ urdf_path="./SO101/so101_new_calib.urdf",
+ target_frame_name="gripper_frame_link",
+ joint_names=list(leader.bus.motors.keys()),
+)
+
+# Build pipeline to convert teleop joints to EE action
+leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
+ steps=[
+ ForwardKinematicsJointsToEE(
+ kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
+ ),
+ ],
+ to_transition=robot_action_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# build pipeline to convert EE action to robot joints
+ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ [
+ EEBoundsAndSafety(
+ end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
+ max_ee_step_m=0.10,
+ ),
+ InverseKinematicsEEToJoints(
+ kinematics=follower_kinematics_solver,
+ motor_names=list(follower.bus.motors.keys()),
+ initial_guess_current_joints=False,
+ ),
+ ],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+)
+
+# Connect to the robot and teleoperator
+follower.connect()
+leader.connect()
+
+# Init rerun viewer
+init_rerun(session_name="so100_so100_EE_teleop")
+
+print("Starting teleop loop...")
+while True:
+ t0 = time.perf_counter()
+
+ # Get robot observation
+ robot_obs = follower.get_observation()
+
+ # Get teleop observation
+ leader_joints_obs = leader.get_action()
+
+ # teleop joints -> teleop EE action
+ leader_ee_act = leader_to_ee(leader_joints_obs)
+
+ # teleop EE -> robot joints
+ follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
+
+ # Send action to robot
+ _ = follower.send_action(follower_joints_act)
+
+ # Visualize
+ log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
+
+ busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
diff --git a/examples/3_train_policy.py b/examples/training/train_policy.py
similarity index 92%
rename from examples/3_train_policy.py
rename to examples/training/train_policy.py
index f2de79db8..16f2a4d87 100644
--- a/examples/3_train_policy.py
+++ b/examples/training/train_policy.py
@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This script demonstrates how to train Diffusion Policy on the PushT environment.
-
-Once you have trained a model with this script, you can try to evaluate it on
-examples/2_evaluate_pretrained_policy.py
-"""
+"""This script demonstrates how to train Diffusion Policy on the PushT environment."""
from pathlib import Path
@@ -27,6 +23,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetad
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
+from lerobot.policies.factory import make_pre_post_processors
def main():
@@ -56,9 +53,10 @@ def main():
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
# We can now instantiate our policy with this config and the dataset stats.
- policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
+ policy = DiffusionPolicy(cfg)
policy.train()
policy.to(device)
+ preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
# which can differ for inputs, outputs and rewards (if there are some).
@@ -99,7 +97,7 @@ def main():
done = False
while not done:
for batch in dataloader:
- batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
+ batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
@@ -114,6 +112,8 @@ def main():
# Save a policy checkpoint.
policy.save_pretrained(output_directory)
+ preprocessor.save_pretrained(output_directory)
+ postprocessor.save_pretrained(output_directory)
if __name__ == "__main__":
diff --git a/examples/training/train_with_streaming.py b/examples/training/train_with_streaming.py
new file mode 100644
index 000000000..185be5b13
--- /dev/null
+++ b/examples/training/train_with_streaming.py
@@ -0,0 +1,108 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This script demonstrates how to train a Diffusion Policy on the PushT environment,
+using a dataset processed in streaming mode."""
+
+from pathlib import Path
+
+import torch
+
+from lerobot.configs.types import FeatureType
+from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+from lerobot.datasets.utils import dataset_to_policy_features
+from lerobot.policies.act.configuration_act import ACTConfig
+from lerobot.policies.act.modeling_act import ACTPolicy
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.utils.constants import ACTION
+
+
+def main():
+ # Create a directory to store the training checkpoint.
+ output_directory = Path("outputs/train/example_streaming_dataset")
+ output_directory.mkdir(parents=True, exist_ok=True)
+
+ # Selects the "best" device available
+ device = (
+ torch.device("cuda")
+ if torch.cuda.is_available()
+ else torch.device("mps")
+ if torch.backends.mps.is_available()
+ else torch.device("cpu")
+ )
+ print(f"Using device: {device}")
+
+ training_steps = 10
+ log_freq = 1
+
+ dataset_id = "lerobot/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (:
+ dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+ features = dataset_to_policy_features(dataset_metadata.features)
+ output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+ input_features = {key: ft for key, ft in features.items() if key not in output_features}
+
+ # We can now instantiate our policy with this config and the dataset stats.
+ cfg = ACTConfig(input_features=input_features, output_features=output_features)
+ policy = ACTPolicy(cfg)
+ policy.train()
+ policy.to(device)
+ preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
+
+ # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
+ # Here, we use delta-timestamps to only provide ground truth actions for supervision
+ delta_timestamps = {
+ ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
+ }
+
+ # Instantiating the training dataset in streaming mode allows to not consume up memory as the data is fetched
+ # iteratively rather than being load into memory all at once. Retrieved frames are shuffled across epochs
+ dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)
+
+ optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ num_workers=4,
+ batch_size=16,
+ pin_memory=device.type != "cpu",
+ drop_last=True,
+ prefetch_factor=2, # loads batches with multiprocessing while policy trains
+ )
+
+ # Run training loop.
+ step = 0
+ done = False
+ while not done:
+ for batch in dataloader:
+ batch = preprocessor(batch)
+ loss, _ = policy.forward(batch)
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if step % log_freq == 0:
+ print(f"step: {step} loss: {loss.item():.3f}")
+ step += 1
+ if step >= training_steps:
+ done = True
+ break
+
+ # Save a policy checkpoint.
+ policy.save_pretrained(output_directory)
+ preprocessor.save_pretrained(output_directory)
+ postprocessor.save_pretrained(output_directory)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/media/hope_jr/hopejr.png b/media/hope_jr/hopejr.png
new file mode 100644
index 000000000..4186547a2
Binary files /dev/null and b/media/hope_jr/hopejr.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 9fc84d903..c67b481f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,116 +12,221 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
[project.urls]
-homepage = "https://github.com/huggingface/lerobot"
+homepage = "https://huggingface.co/lerobot"
+documentation = "https://huggingface.co/docs/lerobot/index"
+source = "https://github.com/huggingface/lerobot"
issues = "https://github.com/huggingface/lerobot/issues"
discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
-version = "0.1.0"
+version = "0.3.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
+readme = "README.md"
+license = { text = "Apache-2.0" }
+requires-python = ">=3.10"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
{ name = "Alexander Soare", email = "alexander.soare159@gmail.com" },
{ name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" },
- { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" },
- { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" },
{ name = "Steven Palma", email = "imstevenpmwork@ieee.org" },
+ { name = "Pepijn Kooijmans", email = "pepijnkooijmans@outlook.com"},
+ { name = "Michel Aractingi", email = "michel.aractingi@gmail.com"},
+ { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" },
+ { name = "Dana Aubakirova", email = "danaaubakirova17@gmail.com"},
+ { name = "Caroline Pascal", email = "caroline8.pascal@gmail.com"},
+ { name = "Martino Russi", email = "nopyeps@gmail.com"},
+ { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" },
]
-readme = "README.md"
-license = { text = "Apache-2.0" }
-requires-python = ">=3.10"
-keywords = ["robotics", "deep learning", "pytorch"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
- "Topic :: Software Development :: Build Tools",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
+ "Topic :: Software Development :: Build Tools",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
+keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artificial intelligence"]
+
dependencies = [
- "cmake>=3.29.0.1",
- "datasets>=2.19.0",
- "deepdiff>=7.0.1",
- "diffusers>=0.27.2",
- "draccus==0.10.0",
- "einops>=0.8.0",
- "flask>=3.0.3",
- "gdown>=5.1.0",
- "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
- "h5py>=3.10.0",
- "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
- "imageio[ffmpeg]>=2.34.0",
- "jsonlines>=4.0.0",
- "numba>=0.59.0",
- "omegaconf>=2.3.0",
- "opencv-python-headless>=4.9.0",
- "packaging>=24.2",
- "av>=14.2.0",
- "pymunk>=6.6.0,<7.0.0",
- "pynput>=1.7.7",
- "pyserial>=3.5",
- "pyzmq>=26.2.1",
- "rerun-sdk>=0.21.0",
- "termcolor>=2.4.0",
- "torch>=2.2.1",
- "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
- "torchvision>=0.21.0",
- "wandb>=0.16.3",
- "zarr>=2.17.0",
+
+ # Hugging Face dependencies
+ "datasets>=4.0.0,<4.2.0",
+ "diffusers>=0.27.2,<0.36.0",
+ "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
+
+ # Core dependencies
+ "cmake>=3.29.0.1,<4.2.0",
+ "einops>=0.8.0,<0.9.0",
+ "opencv-python-headless>=4.9.0,<4.13.0",
+ "av>=14.2.0,<16.0.0",
+ "jsonlines>=4.0.0,<5.0.0",
+ "packaging>=24.2,<26.0",
+ "pynput>=1.7.7,<1.9.0",
+ "pyserial>=3.5,<4.0",
+ "wandb>=0.20.0,<0.23.0",
+
+ "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
+ "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
+ "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
+
+ "draccus==0.10.0", # TODO: Remove ==
+ "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
+ "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
+
+ # Support dependencies
+ "deepdiff>=7.0.1,<9.0.0",
+ "imageio[ffmpeg]>=2.34.0,<3.0.0",
+ "termcolor>=2.4.0,<4.0.0",
]
+# Optional dependencies
[project.optional-dependencies]
-aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
-docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
-dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
-dora = [
- "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
-]
-dynamixel = ["dynamixel-sdk>=3.7.31"]
-feetech = ["feetech-servo-sdk>=1.0.0"]
-gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"]
-kinematics = ["placo>=0.9.6"]
-intelrealsense = [
- "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
- "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
-]
-pi0 = ["transformers>=4.50.3"]
-smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
-pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
-stretch = [
- "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
- "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
- "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
-]
-test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
-hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio==1.71.0", "placo>=0.9.6"]
-umi = ["imagecodecs>=2024.1.1"]
-video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
-xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
-[tool.poetry]
-requires-poetry = ">=2.1"
-packages = [
- { include = "lerobot", from = "src" }
+# Common
+pygame-dep = ["pygame>=2.5.1,<2.7.0"]
+placo-dep = ["placo>=0.9.6,<0.10.0"]
+transformers-dep = ["transformers>=4.53.0,<5.0.0"]
+grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
+
+# Motors
+feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
+dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
+
+# Robots
+gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
+hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
+lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
+reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
+kinematics = ["lerobot[placo-dep]"]
+intelrealsense = [
+ "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
+ "pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
+phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
+# stretch = [
+# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
+# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
+# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
+# ] # TODO: Currently not supported
+
+# Policies
+pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
+smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
+hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
+
+# Features
+async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
+
+# Development
+dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
+test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
+video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
+
+# Simulation
+aloha = ["gym-aloha>=0.1.1,<0.2.0"]
+pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
+xarm = ["gym-xarm>=0.1.1,<0.2.0"]
+libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
+
+
+# All
+all = [
+ "lerobot[dynamixel]",
+ "lerobot[gamepad]",
+ "lerobot[hopejr]",
+ "lerobot[lekiwi]",
+ "lerobot[reachy2]",
+ "lerobot[kinematics]",
+ "lerobot[intelrealsense]",
+ "lerobot[pi]",
+ "lerobot[smolvla]",
+ "lerobot[hilserl]",
+ "lerobot[async]",
+ "lerobot[dev]",
+ "lerobot[test]",
+ "lerobot[video_benchmark]",
+ "lerobot[aloha]",
+ "lerobot[pusht]",
+ "lerobot[xarm]",
+ "lerobot[phone]",
+ "lerobot[libero]",
+]
+
+[project.scripts]
+lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main"
+lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main"
+lerobot-find-port="lerobot.scripts.lerobot_find_port:main"
+lerobot-record="lerobot.scripts.lerobot_record:main"
+lerobot-replay="lerobot.scripts.lerobot_replay:main"
+lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
+lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
+lerobot-eval="lerobot.scripts.lerobot_eval:main"
+lerobot-train="lerobot.scripts.lerobot_train:main"
+lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
+lerobot-info="lerobot.scripts.lerobot_info:main"
+lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
+lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
+
+# ---------------- Tool Configurations ----------------
+[tool.setuptools.packages.find]
+where = ["src"]
[tool.ruff]
-line-length = 110
target-version = "py310"
+line-length = 110
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
[tool.ruff.lint]
-select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
+# E, W: pycodestyle errors and warnings
+# F: PyFlakes
+# I: isort
+# UP: pyupgrade
+# B: flake8-bugbear (good practices, potential bugs)
+# C4: flake8-comprehensions (more concise comprehensions)
+# A: flake8-builtins (shadowing builtins)
+# SIM: flake8-simplify
+# RUF: Ruff-specific rules
+# D: pydocstyle (for docstring style/formatting)
+# S: flake8-bandit (some security checks, complements Bandit)
+# T20: flake8-print (discourage print statements in production code)
+# N: pep8-naming
+# TODO: Uncomment rules when ready to use
+select = [
+ "E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
+]
+ignore = [
+ "E501", # Line too long
+ "T201", # Print statement found
+ "T203", # Pprint statement found
+ "B008", # Perform function call in argument defaults
+]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]
+[tool.ruff.lint.isort]
+combine-as-imports = true
+known-first-party = ["lerobot"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = true
+
[tool.bandit]
exclude_dirs = [
"tests",
@@ -131,7 +236,7 @@ exclude_dirs = [
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
]
-skips = ["B101", "B311", "B404", "B603"]
+skips = ["B101", "B311", "B404", "B603", "B615"]
[tool.typos]
default.extend-ignore-re = [
@@ -146,6 +251,103 @@ default.extend-ignore-identifiers-re = [
"ein",
]
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
+# TODO: Uncomment when ready to use
+# [tool.interrogate]
+# ignore-init-module = true
+# ignore-init-method = true
+# ignore-nested-functions = false
+# ignore-magic = false
+# ignore-semiprivate = false
+# ignore-private = false
+# ignore-property-decorators = false
+# ignore-module = false
+# ignore-setters = false
+# fail-under = 80
+# output-format = "term-missing"
+# color = true
+# paths = ["src/lerobot"]
+
+# TODO: Enable mypy gradually module by module across multiple PRs
+# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
+
+[tool.mypy]
+python_version = "3.10"
+ignore_missing_imports = true
+follow_imports = "skip"
+# warn_return_any = true
+# warn_unused_configs = true
+# strict = true
+# disallow_untyped_defs = true
+# disallow_incomplete_defs = true
+# check_untyped_defs = true
+
+[[tool.mypy.overrides]]
+module = "lerobot.*"
+ignore_errors = true
+
+[[tool.mypy.overrides]]
+module = "lerobot.envs.*"
+# Enable type checking only for the envs module
+ignore_errors = false
+
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.utils.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.configs.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.optim.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.model.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.processor.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.datasets.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.cameras.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.motors.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.robots.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.teleoperators.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.policies.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.rl.*"
+# ignore_errors = false
+
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.async_inference.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.transport.*"
+# ignore_errors = false
+
+# [[tool.mypy.overrides]]
+# module = "lerobot.scripts.*"
+# ignore_errors = false
diff --git a/requirements-macos.txt b/requirements-macos.txt
new file mode 100644
index 000000000..07e263da5
--- /dev/null
+++ b/requirements-macos.txt
@@ -0,0 +1,625 @@
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+# pip-compile --output-file=requirements-macos.txt requirements.in
+#
+-e .[all]
+ # via -[all]
+absl-py==2.3.1
+ # via
+ # dm-control
+ # dm-env
+ # dm-tree
+ # labmaze
+ # mujoco
+accelerate==1.9.0
+ # via lerobot
+aiohappyeyeballs==2.6.1
+ # via aiohttp
+aiohttp==3.12.15
+ # via fsspec
+aiosignal==1.4.0
+ # via aiohttp
+annotated-types==0.7.0
+ # via pydantic
+asttokens==3.0.0
+ # via stack-data
+async-timeout==5.0.1
+ # via aiohttp
+attrs==25.3.0
+ # via
+ # aiohttp
+ # dm-tree
+ # jsonlines
+ # rerun-sdk
+av==15.0.0
+ # via lerobot
+blinker==1.9.0
+ # via flask
+certifi==2025.7.14
+ # via
+ # requests
+ # sentry-sdk
+cffi==1.17.1
+ # via pymunk
+cfgv==3.4.0
+ # via pre-commit
+charset-normalizer==3.4.2
+ # via requests
+click==8.2.1
+ # via
+ # flask
+ # wandb
+cloudpickle==3.1.1
+ # via gymnasium
+cmake==4.0.3
+ # via lerobot
+cmeel==0.57.3
+ # via
+ # cmeel-assimp
+ # cmeel-boost
+ # cmeel-console-bridge
+ # cmeel-octomap
+ # cmeel-qhull
+ # cmeel-tinyxml2
+ # cmeel-urdfdom
+ # cmeel-zlib
+ # coal-library
+ # eigenpy
+ # eiquadprog
+ # pin
+ # placo
+ # rhoban-cmeel-jsoncpp
+cmeel-assimp==5.4.3.1
+ # via coal-library
+cmeel-boost==1.87.0.1
+ # via
+ # coal-library
+ # eigenpy
+ # eiquadprog
+ # pin
+cmeel-console-bridge==1.0.2.3
+ # via cmeel-urdfdom
+cmeel-octomap==1.10.0
+ # via coal-library
+cmeel-qhull==8.0.2.1
+ # via coal-library
+cmeel-tinyxml2==10.0.0
+ # via cmeel-urdfdom
+cmeel-urdfdom==4.0.1
+ # via pin
+cmeel-zlib==1.3.1
+ # via cmeel-assimp
+coal-library==3.0.1
+ # via pin
+contourpy==1.3.2
+ # via matplotlib
+coverage[toml]==7.10.1
+ # via pytest-cov
+cycler==0.12.1
+ # via matplotlib
+datasets==3.6.0
+ # via lerobot
+debugpy==1.8.15
+ # via lerobot
+decorator==5.2.1
+ # via ipython
+deepdiff==8.5.0
+ # via lerobot
+diffusers==0.34.0
+ # via lerobot
+dill==0.3.8
+ # via
+ # datasets
+ # multiprocess
+distlib==0.4.0
+ # via virtualenv
+dm-control==1.0.14
+ # via gym-aloha
+dm-env==1.6
+ # via dm-control
+dm-tree==0.1.9
+ # via
+ # dm-control
+ # dm-env
+docopt==0.6.2
+ # via num2words
+draccus==0.10.0
+ # via lerobot
+dynamixel-sdk==3.7.31
+ # via lerobot
+eigenpy==3.10.3
+ # via coal-library
+einops==0.8.1
+ # via lerobot
+eiquadprog==1.2.9
+ # via placo
+exceptiongroup==1.3.0
+ # via
+ # ipython
+ # pytest
+executing==2.2.0
+ # via stack-data
+farama-notifications==0.0.4
+ # via gymnasium
+feetech-servo-sdk==1.0.0
+ # via lerobot
+filelock==3.18.0
+ # via
+ # datasets
+ # diffusers
+ # huggingface-hub
+ # torch
+ # transformers
+ # virtualenv
+flask==3.1.1
+ # via lerobot
+fonttools==4.59.0
+ # via matplotlib
+frozenlist==1.7.0
+ # via
+ # aiohttp
+ # aiosignal
+fsspec[http]==2025.3.0
+ # via
+ # datasets
+ # huggingface-hub
+ # torch
+gitdb==4.0.12
+ # via gitpython
+gitpython==3.1.45
+ # via wandb
+glfw==2.9.0
+ # via
+ # dm-control
+ # mujoco
+grpcio==1.73.1
+ # via
+ # grpcio-tools
+ # lerobot
+grpcio-tools==1.73.1
+ # via lerobot
+gym-aloha==0.1.1
+ # via lerobot
+gym-hil==0.1.10
+ # via lerobot
+gym-pusht==0.1.5
+ # via lerobot
+gym-xarm==0.1.1
+ # via lerobot
+gymnasium==0.29.1
+ # via
+ # gym-aloha
+ # gym-hil
+ # gym-pusht
+ # gym-xarm
+ # gymnasium-robotics
+ # lerobot
+ # pettingzoo
+gymnasium-robotics==1.2.4
+ # via gym-xarm
+hf-transfer==0.1.9
+ # via huggingface-hub
+hf-xet==1.1.5
+ # via huggingface-hub
+hidapi==0.14.0.post4
+ # via
+ # gym-hil
+ # lerobot
+huggingface-hub[cli,hf-transfer]==0.34.3
+ # via
+ # accelerate
+ # datasets
+ # diffusers
+ # lerobot
+ # tokenizers
+ # transformers
+identify==2.6.12
+ # via pre-commit
+idna==3.10
+ # via
+ # requests
+ # yarl
+imageio[ffmpeg]==2.37.0
+ # via
+ # gym-aloha
+ # gym-hil
+ # gymnasium-robotics
+ # lerobot
+ # scikit-image
+imageio-ffmpeg==0.6.0
+ # via imageio
+importlib-metadata==8.7.0
+ # via diffusers
+iniconfig==2.1.0
+ # via pytest
+inquirerpy==0.3.4
+ # via huggingface-hub
+ipython==8.37.0
+ # via meshcat
+ischedule==1.2.7
+ # via placo
+itsdangerous==2.2.0
+ # via flask
+jedi==0.19.2
+ # via ipython
+jinja2==3.1.6
+ # via
+ # flask
+ # gymnasium-robotics
+ # torch
+jsonlines==4.0.0
+ # via lerobot
+kiwisolver==1.4.8
+ # via matplotlib
+labmaze==1.0.6
+ # via dm-control
+lazy-loader==0.4
+ # via scikit-image
+lxml==6.0.0
+ # via dm-control
+markupsafe==3.0.2
+ # via
+ # flask
+ # jinja2
+ # werkzeug
+matplotlib==3.10.5
+ # via lerobot
+matplotlib-inline==0.1.7
+ # via ipython
+mergedeep==1.3.4
+ # via draccus
+meshcat==0.3.2
+ # via placo
+mock-serial==0.0.1
+ # via lerobot
+mpmath==1.3.0
+ # via sympy
+mujoco==2.3.7
+ # via
+ # dm-control
+ # gym-aloha
+ # gym-hil
+ # gym-xarm
+ # gymnasium-robotics
+multidict==6.6.3
+ # via
+ # aiohttp
+ # yarl
+multiprocess==0.70.16
+ # via datasets
+mypy-extensions==1.1.0
+ # via typing-inspect
+networkx==3.4.2
+ # via
+ # scikit-image
+ # torch
+nodeenv==1.9.1
+ # via pre-commit
+num2words==0.5.14
+ # via lerobot
+numpy==2.2.6
+ # via
+ # accelerate
+ # cmeel-boost
+ # contourpy
+ # datasets
+ # diffusers
+ # dm-control
+ # dm-env
+ # dm-tree
+ # gymnasium
+ # gymnasium-robotics
+ # imageio
+ # labmaze
+ # matplotlib
+ # meshcat
+ # mujoco
+ # opencv-python
+ # opencv-python-headless
+ # pandas
+ # pettingzoo
+ # rerun-sdk
+ # scikit-image
+ # scipy
+ # shapely
+ # tifffile
+ # torchvision
+ # transformers
+opencv-python==4.12.0.88
+ # via gym-pusht
+opencv-python-headless==4.12.0.88
+ # via lerobot
+orderly-set==5.5.0
+ # via deepdiff
+packaging==25.0
+ # via
+ # accelerate
+ # datasets
+ # huggingface-hub
+ # lazy-loader
+ # lerobot
+ # matplotlib
+ # pytest
+ # scikit-image
+ # transformers
+ # wandb
+pandas==2.3.1
+ # via
+ # datasets
+ # lerobot
+parso==0.8.4
+ # via jedi
+pettingzoo==1.24.3
+ # via gymnasium-robotics
+pexpect==4.9.0
+ # via ipython
+pfzy==0.3.4
+ # via inquirerpy
+pillow==11.3.0
+ # via
+ # diffusers
+ # imageio
+ # matplotlib
+ # meshcat
+ # rerun-sdk
+ # scikit-image
+ # torchvision
+pin==3.4.0
+ # via placo
+placo==0.9.14
+ # via lerobot
+platformdirs==4.3.8
+ # via
+ # virtualenv
+ # wandb
+pluggy==1.6.0
+ # via
+ # pytest
+ # pytest-cov
+pre-commit==4.2.0
+ # via lerobot
+prompt-toolkit==3.0.51
+ # via
+ # inquirerpy
+ # ipython
+propcache==0.3.2
+ # via
+ # aiohttp
+ # yarl
+protobuf==6.31.0
+ # via
+ # dm-control
+ # grpcio-tools
+ # lerobot
+ # wandb
+psutil==7.0.0
+ # via
+ # accelerate
+ # imageio
+ptyprocess==0.7.0
+ # via pexpect
+pure-eval==0.2.3
+ # via stack-data
+pyarrow==21.0.0
+ # via
+ # datasets
+ # rerun-sdk
+pycparser==2.22
+ # via cffi
+pydantic==2.11.7
+ # via wandb
+pydantic-core==2.33.2
+ # via pydantic
+pygame==2.6.1
+ # via
+ # gym-hil
+ # gym-pusht
+ # lerobot
+pygments==2.19.2
+ # via
+ # ipython
+ # pytest
+pymunk==6.11.1
+ # via
+ # gym-pusht
+ # lerobot
+pyngrok==7.2.12
+ # via meshcat
+pynput==1.8.1
+ # via
+ # gym-hil
+ # lerobot
+pyobjc-core==11.1
+ # via
+ # pyobjc-framework-applicationservices
+ # pyobjc-framework-cocoa
+ # pyobjc-framework-coretext
+ # pyobjc-framework-quartz
+pyobjc-framework-applicationservices==11.1
+ # via pynput
+pyobjc-framework-cocoa==11.1
+ # via
+ # pyobjc-framework-applicationservices
+ # pyobjc-framework-coretext
+ # pyobjc-framework-quartz
+pyobjc-framework-coretext==11.1
+ # via pyobjc-framework-applicationservices
+pyobjc-framework-quartz==11.1
+ # via
+ # pynput
+ # pyobjc-framework-applicationservices
+ # pyobjc-framework-coretext
+pyopengl==3.1.9
+ # via
+ # dm-control
+ # mujoco
+pyparsing==3.2.3
+ # via
+ # dm-control
+ # matplotlib
+pyrealsense2-macosx==2.54.2
+ # via lerobot
+pyserial==3.5
+ # via
+ # dynamixel-sdk
+ # feetech-servo-sdk
+ # lerobot
+pytest==8.4.1
+ # via
+ # lerobot
+ # pytest-cov
+ # pytest-timeout
+pytest-cov==6.2.1
+ # via lerobot
+pytest-timeout==2.4.0
+ # via lerobot
+python-dateutil==2.9.0.post0
+ # via
+ # matplotlib
+ # pandas
+pytz==2025.2
+ # via pandas
+pyyaml==6.0.2
+ # via
+ # accelerate
+ # datasets
+ # draccus
+ # huggingface-hub
+ # pre-commit
+ # pyngrok
+ # pyyaml-include
+ # transformers
+ # wandb
+pyyaml-include==1.4.1
+ # via draccus
+pyzmq==27.0.0
+ # via
+ # lerobot
+ # meshcat
+regex==2025.7.34
+ # via
+ # diffusers
+ # transformers
+requests==2.32.4
+ # via
+ # datasets
+ # diffusers
+ # dm-control
+ # huggingface-hub
+ # transformers
+ # wandb
+rerun-sdk==0.22.1
+ # via lerobot
+rhoban-cmeel-jsoncpp==1.9.4.9
+ # via placo
+safetensors==0.5.3
+ # via
+ # accelerate
+ # diffusers
+ # lerobot
+ # transformers
+scikit-image==0.25.2
+ # via
+ # gym-pusht
+ # lerobot
+scipy==1.15.3
+ # via
+ # dm-control
+ # scikit-image
+sentry-sdk==2.34.1
+ # via wandb
+shapely==2.1.1
+ # via gym-pusht
+six==1.17.0
+ # via
+ # pynput
+ # python-dateutil
+smmap==5.0.2
+ # via gitdb
+stack-data==0.6.3
+ # via ipython
+sympy==1.14.0
+ # via torch
+termcolor==3.1.0
+ # via lerobot
+tifffile==2025.5.10
+ # via scikit-image
+tokenizers==0.21.4
+ # via transformers
+toml==0.10.2
+ # via draccus
+tomli==2.2.1
+ # via
+ # cmeel
+ # coverage
+ # pytest
+torch==2.7.1
+ # via
+ # accelerate
+ # lerobot
+ # torchvision
+torchcodec==0.5
+ # via lerobot
+torchvision==0.22.1
+ # via lerobot
+tornado==6.5.1
+ # via meshcat
+tqdm==4.67.1
+ # via
+ # datasets
+ # dm-control
+ # huggingface-hub
+ # transformers
+traitlets==5.14.3
+ # via
+ # ipython
+ # matplotlib-inline
+transformers==4.51.3
+ # via lerobot
+typing-extensions==4.14.1
+ # via
+ # aiosignal
+ # exceptiongroup
+ # gymnasium
+ # huggingface-hub
+ # ipython
+ # multidict
+ # pydantic
+ # pydantic-core
+ # rerun-sdk
+ # torch
+ # typing-inspect
+ # typing-inspection
+ # wandb
+typing-inspect==0.9.0
+ # via draccus
+typing-inspection==0.4.1
+ # via pydantic
+tzdata==2025.2
+ # via pandas
+u-msgpack-python==2.8.0
+ # via meshcat
+urllib3==2.5.0
+ # via
+ # requests
+ # sentry-sdk
+virtualenv==20.32.0
+ # via pre-commit
+wandb==0.21.0
+ # via lerobot
+wcwidth==0.2.13
+ # via prompt-toolkit
+werkzeug==3.1.3
+ # via flask
+wrapt==1.17.2
+ # via dm-tree
+xxhash==3.5.0
+ # via datasets
+yarl==1.20.1
+ # via aiohttp
+zipp==3.23.0
+ # via importlib-metadata
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/requirements-ubuntu.txt b/requirements-ubuntu.txt
new file mode 100644
index 000000000..af7258d67
--- /dev/null
+++ b/requirements-ubuntu.txt
@@ -0,0 +1,650 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+# pip-compile --output-file=requirements-ubuntu.txt requirements.in
+#
+-e .[all]
+ # via -[all]
+absl-py==2.3.1
+ # via
+ # dm-control
+ # dm-env
+ # dm-tree
+ # labmaze
+ # mujoco
+accelerate==1.9.0
+ # via lerobot
+aiohappyeyeballs==2.6.1
+ # via aiohttp
+aiohttp==3.12.15
+ # via fsspec
+aiosignal==1.4.0
+ # via aiohttp
+annotated-types==0.7.0
+ # via pydantic
+asttokens==3.0.0
+ # via stack-data
+async-timeout==5.0.1
+ # via aiohttp
+attrs==25.3.0
+ # via
+ # aiohttp
+ # dm-tree
+ # jsonlines
+ # rerun-sdk
+av==15.0.0
+ # via lerobot
+blinker==1.9.0
+ # via flask
+certifi==2025.7.14
+ # via
+ # requests
+ # sentry-sdk
+cffi==1.17.1
+ # via pymunk
+cfgv==3.4.0
+ # via pre-commit
+charset-normalizer==3.4.2
+ # via requests
+click==8.2.1
+ # via
+ # flask
+ # wandb
+cloudpickle==3.1.1
+ # via gymnasium
+cmake==4.0.3
+ # via lerobot
+cmeel==0.57.3
+ # via
+ # cmeel-assimp
+ # cmeel-boost
+ # cmeel-console-bridge
+ # cmeel-octomap
+ # cmeel-qhull
+ # cmeel-tinyxml2
+ # cmeel-urdfdom
+ # cmeel-zlib
+ # coal-library
+ # eigenpy
+ # eiquadprog
+ # pin
+ # placo
+ # rhoban-cmeel-jsoncpp
+cmeel-assimp==5.4.3.1
+ # via coal-library
+cmeel-boost==1.87.0.1
+ # via
+ # coal-library
+ # eigenpy
+ # eiquadprog
+ # pin
+cmeel-console-bridge==1.0.2.3
+ # via cmeel-urdfdom
+cmeel-octomap==1.10.0
+ # via coal-library
+cmeel-qhull==8.0.2.1
+ # via coal-library
+cmeel-tinyxml2==10.0.0
+ # via cmeel-urdfdom
+cmeel-urdfdom==4.0.1
+ # via pin
+cmeel-zlib==1.3.1
+ # via cmeel-assimp
+coal-library==3.0.1
+ # via pin
+contourpy==1.3.2
+ # via matplotlib
+coverage[toml]==7.10.1
+ # via pytest-cov
+cycler==0.12.1
+ # via matplotlib
+datasets==3.6.0
+ # via lerobot
+debugpy==1.8.15
+ # via lerobot
+decorator==5.2.1
+ # via ipython
+deepdiff==8.5.0
+ # via lerobot
+diffusers==0.34.0
+ # via lerobot
+dill==0.3.8
+ # via
+ # datasets
+ # multiprocess
+distlib==0.4.0
+ # via virtualenv
+dm-control==1.0.14
+ # via gym-aloha
+dm-env==1.6
+ # via dm-control
+dm-tree==0.1.9
+ # via
+ # dm-control
+ # dm-env
+docopt==0.6.2
+ # via num2words
+draccus==0.10.0
+ # via lerobot
+dynamixel-sdk==3.7.31
+ # via lerobot
+eigenpy==3.10.3
+ # via coal-library
+einops==0.8.1
+ # via lerobot
+eiquadprog==1.2.9
+ # via placo
+evdev==1.9.2
+ # via pynput
+exceptiongroup==1.3.0
+ # via
+ # ipython
+ # pytest
+executing==2.2.0
+ # via stack-data
+farama-notifications==0.0.4
+ # via gymnasium
+feetech-servo-sdk==1.0.0
+ # via lerobot
+filelock==3.18.0
+ # via
+ # datasets
+ # diffusers
+ # huggingface-hub
+ # torch
+ # transformers
+ # virtualenv
+flask==3.1.1
+ # via lerobot
+fonttools==4.59.0
+ # via matplotlib
+frozenlist==1.7.0
+ # via
+ # aiohttp
+ # aiosignal
+fsspec[http]==2025.3.0
+ # via
+ # datasets
+ # huggingface-hub
+ # torch
+gitdb==4.0.12
+ # via gitpython
+gitpython==3.1.45
+ # via wandb
+glfw==2.9.0
+ # via
+ # dm-control
+ # mujoco
+grpcio==1.73.1
+ # via
+ # grpcio-tools
+ # lerobot
+grpcio-tools==1.73.1
+ # via lerobot
+gym-aloha==0.1.1
+ # via lerobot
+gym-hil==0.1.10
+ # via lerobot
+gym-pusht==0.1.5
+ # via lerobot
+gym-xarm==0.1.1
+ # via lerobot
+gymnasium==0.29.1
+ # via
+ # gym-aloha
+ # gym-hil
+ # gym-pusht
+ # gym-xarm
+ # gymnasium-robotics
+ # lerobot
+ # pettingzoo
+gymnasium-robotics==1.2.4
+ # via gym-xarm
+hf-transfer==0.1.9
+ # via huggingface-hub
+hf-xet==1.1.5
+ # via huggingface-hub
+hidapi==0.14.0.post4
+ # via
+ # gym-hil
+ # lerobot
+huggingface-hub[cli,hf-transfer]==0.34.3
+ # via
+ # accelerate
+ # datasets
+ # diffusers
+ # lerobot
+ # tokenizers
+ # transformers
+identify==2.6.12
+ # via pre-commit
+idna==3.10
+ # via
+ # requests
+ # yarl
+imageio[ffmpeg]==2.37.0
+ # via
+ # gym-aloha
+ # gym-hil
+ # gymnasium-robotics
+ # lerobot
+ # scikit-image
+imageio-ffmpeg==0.6.0
+ # via imageio
+importlib-metadata==8.7.0
+ # via diffusers
+iniconfig==2.1.0
+ # via pytest
+inquirerpy==0.3.4
+ # via huggingface-hub
+ipython==8.37.0
+ # via meshcat
+ischedule==1.2.7
+ # via placo
+itsdangerous==2.2.0
+ # via flask
+jedi==0.19.2
+ # via ipython
+jinja2==3.1.6
+ # via
+ # flask
+ # gymnasium-robotics
+ # torch
+jsonlines==4.0.0
+ # via lerobot
+kiwisolver==1.4.8
+ # via matplotlib
+labmaze==1.0.6
+ # via dm-control
+lazy-loader==0.4
+ # via scikit-image
+lxml==6.0.0
+ # via dm-control
+markupsafe==3.0.2
+ # via
+ # flask
+ # jinja2
+ # werkzeug
+matplotlib==3.10.5
+ # via lerobot
+matplotlib-inline==0.1.7
+ # via ipython
+mergedeep==1.3.4
+ # via draccus
+meshcat==0.3.2
+ # via placo
+mock-serial==0.0.1
+ # via lerobot
+mpmath==1.3.0
+ # via sympy
+mujoco==2.3.7
+ # via
+ # dm-control
+ # gym-aloha
+ # gym-hil
+ # gym-xarm
+ # gymnasium-robotics
+multidict==6.6.3
+ # via
+ # aiohttp
+ # yarl
+multiprocess==0.70.16
+ # via datasets
+mypy-extensions==1.1.0
+ # via typing-inspect
+networkx==3.4.2
+ # via
+ # scikit-image
+ # torch
+nodeenv==1.9.1
+ # via pre-commit
+num2words==0.5.14
+ # via lerobot
+numpy==2.2.6
+ # via
+ # accelerate
+ # cmeel-boost
+ # contourpy
+ # datasets
+ # diffusers
+ # dm-control
+ # dm-env
+ # dm-tree
+ # gymnasium
+ # gymnasium-robotics
+ # imageio
+ # labmaze
+ # matplotlib
+ # meshcat
+ # mujoco
+ # opencv-python
+ # opencv-python-headless
+ # pandas
+ # pettingzoo
+ # rerun-sdk
+ # scikit-image
+ # scipy
+ # shapely
+ # tifffile
+ # torchvision
+ # transformers
+nvidia-cublas-cu12==12.6.4.1
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-cuda-cupti-cu12==12.6.80
+ # via torch
+nvidia-cuda-nvrtc-cu12==12.6.77
+ # via torch
+nvidia-cuda-runtime-cu12==12.6.77
+ # via torch
+nvidia-cudnn-cu12==9.5.1.17
+ # via torch
+nvidia-cufft-cu12==11.3.0.4
+ # via torch
+nvidia-cufile-cu12==1.11.1.6
+ # via torch
+nvidia-curand-cu12==10.3.7.77
+ # via torch
+nvidia-cusolver-cu12==11.7.1.2
+ # via torch
+nvidia-cusparse-cu12==12.5.4.2
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-cusparselt-cu12==0.6.3
+ # via torch
+nvidia-nccl-cu12==2.26.2
+ # via torch
+nvidia-nvjitlink-cu12==12.6.85
+ # via
+ # nvidia-cufft-cu12
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+nvidia-nvtx-cu12==12.6.77
+ # via torch
+opencv-python==4.12.0.88
+ # via gym-pusht
+opencv-python-headless==4.12.0.88
+ # via lerobot
+orderly-set==5.5.0
+ # via deepdiff
+packaging==25.0
+ # via
+ # accelerate
+ # datasets
+ # huggingface-hub
+ # lazy-loader
+ # lerobot
+ # matplotlib
+ # pytest
+ # scikit-image
+ # transformers
+ # wandb
+pandas==2.3.1
+ # via
+ # datasets
+ # lerobot
+parso==0.8.4
+ # via jedi
+pettingzoo==1.24.3
+ # via gymnasium-robotics
+pexpect==4.9.0
+ # via ipython
+pfzy==0.3.4
+ # via inquirerpy
+pillow==11.3.0
+ # via
+ # diffusers
+ # imageio
+ # matplotlib
+ # meshcat
+ # rerun-sdk
+ # scikit-image
+ # torchvision
+pin==3.4.0
+ # via placo
+placo==0.9.14
+ # via lerobot
+platformdirs==4.3.8
+ # via
+ # virtualenv
+ # wandb
+pluggy==1.6.0
+ # via
+ # pytest
+ # pytest-cov
+pre-commit==4.2.0
+ # via lerobot
+prompt-toolkit==3.0.51
+ # via
+ # inquirerpy
+ # ipython
+propcache==0.3.2
+ # via
+ # aiohttp
+ # yarl
+protobuf==6.31.0
+ # via
+ # dm-control
+ # grpcio-tools
+ # lerobot
+ # wandb
+psutil==7.0.0
+ # via
+ # accelerate
+ # imageio
+ptyprocess==0.7.0
+ # via pexpect
+pure-eval==0.2.3
+ # via stack-data
+pyarrow==21.0.0
+ # via
+ # datasets
+ # rerun-sdk
+pycparser==2.22
+ # via cffi
+pydantic==2.11.7
+ # via wandb
+pydantic-core==2.33.2
+ # via pydantic
+pygame==2.6.1
+ # via
+ # gym-hil
+ # gym-pusht
+ # lerobot
+pygments==2.19.2
+ # via
+ # ipython
+ # pytest
+pymunk==6.11.1
+ # via
+ # gym-pusht
+ # lerobot
+pyngrok==7.2.12
+ # via meshcat
+pynput==1.8.1
+ # via
+ # gym-hil
+ # lerobot
+pyopengl==3.1.9
+ # via
+ # dm-control
+ # mujoco
+pyparsing==3.2.3
+ # via
+ # dm-control
+ # matplotlib
+pyrealsense2==2.56.5.9235
+ # via lerobot
+pyserial==3.5
+ # via
+ # dynamixel-sdk
+ # feetech-servo-sdk
+ # lerobot
+pytest==8.4.1
+ # via
+ # lerobot
+ # pytest-cov
+ # pytest-timeout
+pytest-cov==6.2.1
+ # via lerobot
+pytest-timeout==2.4.0
+ # via lerobot
+python-dateutil==2.9.0.post0
+ # via
+ # matplotlib
+ # pandas
+python-xlib==0.33
+ # via pynput
+pytz==2025.2
+ # via pandas
+pyyaml==6.0.2
+ # via
+ # accelerate
+ # datasets
+ # draccus
+ # huggingface-hub
+ # pre-commit
+ # pyngrok
+ # pyyaml-include
+ # transformers
+ # wandb
+pyyaml-include==1.4.1
+ # via draccus
+pyzmq==27.0.0
+ # via
+ # lerobot
+ # meshcat
+regex==2025.7.34
+ # via
+ # diffusers
+ # transformers
+requests==2.32.4
+ # via
+ # datasets
+ # diffusers
+ # dm-control
+ # huggingface-hub
+ # transformers
+ # wandb
+rerun-sdk==0.22.1
+ # via lerobot
+rhoban-cmeel-jsoncpp==1.9.4.9
+ # via placo
+safetensors==0.5.3
+ # via
+ # accelerate
+ # diffusers
+ # lerobot
+ # transformers
+scikit-image==0.25.2
+ # via
+ # gym-pusht
+ # lerobot
+scipy==1.15.3
+ # via
+ # dm-control
+ # scikit-image
+sentry-sdk==2.34.1
+ # via wandb
+shapely==2.1.1
+ # via gym-pusht
+six==1.17.0
+ # via
+ # pynput
+ # python-dateutil
+ # python-xlib
+smmap==5.0.2
+ # via gitdb
+stack-data==0.6.3
+ # via ipython
+sympy==1.14.0
+ # via torch
+termcolor==3.1.0
+ # via lerobot
+tifffile==2025.5.10
+ # via scikit-image
+tokenizers==0.21.4
+ # via transformers
+toml==0.10.2
+ # via draccus
+tomli==2.2.1
+ # via
+ # cmeel
+ # coverage
+ # pytest
+torch==2.7.1
+ # via
+ # accelerate
+ # lerobot
+ # torchvision
+torchcodec==0.5
+ # via lerobot
+torchvision==0.22.1
+ # via lerobot
+tornado==6.5.1
+ # via meshcat
+tqdm==4.67.1
+ # via
+ # datasets
+ # dm-control
+ # huggingface-hub
+ # transformers
+traitlets==5.14.3
+ # via
+ # ipython
+ # matplotlib-inline
+transformers==4.51.3
+ # via lerobot
+triton==3.3.1
+ # via torch
+typing-extensions==4.14.1
+ # via
+ # aiosignal
+ # exceptiongroup
+ # gymnasium
+ # huggingface-hub
+ # ipython
+ # multidict
+ # pydantic
+ # pydantic-core
+ # rerun-sdk
+ # torch
+ # typing-inspect
+ # typing-inspection
+ # wandb
+typing-inspect==0.9.0
+ # via draccus
+typing-inspection==0.4.1
+ # via pydantic
+tzdata==2025.2
+ # via pandas
+u-msgpack-python==2.8.0
+ # via meshcat
+urllib3==2.5.0
+ # via
+ # requests
+ # sentry-sdk
+virtualenv==20.32.0
+ # via pre-commit
+wandb==0.21.0
+ # via lerobot
+wcwidth==0.2.13
+ # via prompt-toolkit
+werkzeug==3.1.3
+ # via flask
+wrapt==1.17.2
+ # via dm-tree
+xxhash==3.5.0
+ # via datasets
+yarl==1.20.1
+ # via aiohttp
+zipp==3.23.0
+ # via importlib-metadata
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
diff --git a/requirements.in b/requirements.in
new file mode 100644
index 000000000..272f7f540
--- /dev/null
+++ b/requirements.in
@@ -0,0 +1,9 @@
+# requirements.in
+
+# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
+# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
+
+# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
+# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
+
+-e .[all]
diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py
index 38d4e8644..9d3ed1893 100644
--- a/src/lerobot/__init__.py
+++ b/src/lerobot/__init__.py
@@ -170,7 +170,7 @@ available_datasets = sorted(
# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
-# lists all available robots from `lerobot/robot_devices/robots`
+# lists all available robots from `lerobot/robots`
available_robots = [
"koch",
"koch_bimanual",
@@ -179,13 +179,13 @@ available_robots = [
"so101",
]
-# lists all available cameras from `lerobot/robot_devices/cameras`
+# lists all available cameras from `lerobot/cameras`
available_cameras = [
"opencv",
"intelrealsense",
]
-# lists all available motors from `lerobot/robot_devices/motors`
+# lists all available motors from `lerobot/motors`
available_motors = [
"dynamixel",
"feetech",
diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py
new file mode 100644
index 000000000..24f889df1
--- /dev/null
+++ b/src/lerobot/async_inference/configs.py
@@ -0,0 +1,198 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass, field
+
+import torch
+
+from lerobot.robots.config import RobotConfig
+
+from .constants import (
+ DEFAULT_FPS,
+ DEFAULT_INFERENCE_LATENCY,
+ DEFAULT_OBS_QUEUE_TIMEOUT,
+)
+
+# Aggregate function registry for CLI usage
+AGGREGATE_FUNCTIONS = {
+ "weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
+ "latest_only": lambda old, new: new,
+ "average": lambda old, new: 0.5 * old + 0.5 * new,
+ "conservative": lambda old, new: 0.7 * old + 0.3 * new,
+}
+
+
+def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
+ """Get aggregate function by name from registry."""
+ if name not in AGGREGATE_FUNCTIONS:
+ available = list(AGGREGATE_FUNCTIONS.keys())
+ raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
+ return AGGREGATE_FUNCTIONS[name]
+
+
+@dataclass
+class PolicyServerConfig:
+ """Configuration for PolicyServer.
+
+ This class defines all configurable parameters for the PolicyServer,
+ including networking settings and action chunking specifications.
+ """
+
+ # Networking configuration
+ host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
+ port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
+
+ # Timing configuration
+ fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
+ inference_latency: float = field(
+ default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
+ )
+
+ obs_queue_timeout: float = field(
+ default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
+ )
+
+ def __post_init__(self):
+ """Validate configuration after initialization."""
+ if self.port < 1 or self.port > 65535:
+ raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
+
+ if self.environment_dt <= 0:
+ raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
+
+ if self.inference_latency < 0:
+ raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
+
+ if self.obs_queue_timeout < 0:
+ raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
+
+ @classmethod
+ def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
+ """Create a PolicyServerConfig from a dictionary."""
+ return cls(**config_dict)
+
+ @property
+ def environment_dt(self) -> float:
+ """Environment time step, in seconds"""
+ return 1 / self.fps
+
+ def to_dict(self) -> dict:
+ """Convert the configuration to a dictionary."""
+ return {
+ "host": self.host,
+ "port": self.port,
+ "fps": self.fps,
+ "environment_dt": self.environment_dt,
+ "inference_latency": self.inference_latency,
+ }
+
+
+@dataclass
+class RobotClientConfig:
+ """Configuration for RobotClient.
+
+ This class defines all configurable parameters for the RobotClient,
+ including network connection, policy settings, and control behavior.
+ """
+
+ # Policy configuration
+ policy_type: str = field(metadata={"help": "Type of policy to use"})
+ pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
+
+ # Robot configuration (for CLI usage - robot instance will be created from this)
+ robot: RobotConfig = field(metadata={"help": "Robot configuration"})
+
+ # Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
+ # would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
+ actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
+
+ # Task instruction for the robot to execute (e.g., 'fold my tshirt')
+ task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
+
+ # Network configuration
+ server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
+
+ # Device configuration
+ policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
+
+ # Control behavior configuration
+ chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
+ fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
+
+ # Aggregate function configuration (CLI-compatible)
+ aggregate_fn_name: str = field(
+ default="weighted_average",
+ metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
+ )
+
+ # Debug configuration
+ debug_visualize_queue_size: bool = field(
+ default=False, metadata={"help": "Visualize the action queue size"}
+ )
+
+ # Verification configuration
+ verify_robot_cameras: bool = field(
+ default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
+ )
+
+ @property
+ def environment_dt(self) -> float:
+ """Environment time step, in seconds"""
+ return 1 / self.fps
+
+ def __post_init__(self):
+ """Validate configuration after initialization."""
+ if not self.server_address:
+ raise ValueError("server_address cannot be empty")
+
+ if not self.policy_type:
+ raise ValueError("policy_type cannot be empty")
+
+ if not self.pretrained_name_or_path:
+ raise ValueError("pretrained_name_or_path cannot be empty")
+
+ if not self.policy_device:
+ raise ValueError("policy_device cannot be empty")
+
+ if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
+ raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
+
+ if self.fps <= 0:
+ raise ValueError(f"fps must be positive, got {self.fps}")
+
+ if self.actions_per_chunk <= 0:
+ raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
+
+ self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
+
+ @classmethod
+ def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
+ """Create a RobotClientConfig from a dictionary."""
+ return cls(**config_dict)
+
+ def to_dict(self) -> dict:
+ """Convert the configuration to a dictionary."""
+ return {
+ "server_address": self.server_address,
+ "policy_type": self.policy_type,
+ "pretrained_name_or_path": self.pretrained_name_or_path,
+ "policy_device": self.policy_device,
+ "chunk_size_threshold": self.chunk_size_threshold,
+ "fps": self.fps,
+ "actions_per_chunk": self.actions_per_chunk,
+ "task": self.task,
+ "debug_visualize_queue_size": self.debug_visualize_queue_size,
+ "aggregate_fn_name": self.aggregate_fn_name,
+ }
diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py
new file mode 100644
index 000000000..1b1dac0f5
--- /dev/null
+++ b/src/lerobot/async_inference/constants.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Client side: The environment evolves with a time resolution equal to 1/fps"""
+
+DEFAULT_FPS = 30
+
+"""Server side: Running inference on (at most) 1/fps"""
+DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
+
+"""Server side: Timeout for observation queue in seconds"""
+DEFAULT_OBS_QUEUE_TIMEOUT = 2
+
+# All action chunking policies
+SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
+
+# TODO: Add all other robots
+SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
new file mode 100644
index 000000000..54fad8c54
--- /dev/null
+++ b/src/lerobot/async_inference/helpers.py
@@ -0,0 +1,304 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import logging.handlers
+import os
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch
+
+from lerobot.configs.types import PolicyFeature
+from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
+
+# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
+from lerobot.policies import ( # noqa: F401
+ ACTConfig,
+ DiffusionConfig,
+ PI0Config,
+ PI05Config,
+ SmolVLAConfig,
+ VQBeTConfig,
+)
+from lerobot.robots.robot import Robot
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
+from lerobot.utils.utils import init_logging
+
+Action = torch.Tensor
+
+# observation as received from the robot
+RawObservation = dict[str, torch.Tensor]
+
+# observation as those recorded in LeRobot dataset (keys are different)
+LeRobotObservation = dict[str, torch.Tensor]
+
+# observation, ready for policy inference (image keys resized)
+Observation = dict[str, torch.Tensor]
+
+
+def visualize_action_queue_size(action_queue_size: list[int]) -> None:
+ import matplotlib.pyplot as plt
+
+ _, ax = plt.subplots()
+ ax.set_title("Action Queue Size Over Time")
+ ax.set_xlabel("Environment steps")
+ ax.set_ylabel("Action Queue Size")
+ ax.set_ylim(0, max(action_queue_size) * 1.1)
+ ax.grid(True, alpha=0.3)
+ ax.plot(range(len(action_queue_size)), action_queue_size)
+ plt.show()
+
+
+def validate_robot_cameras_for_policy(
+ lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
+) -> None:
+ image_keys = list(filter(is_image_key, lerobot_observation_features))
+ assert set(image_keys) == set(policy_image_features.keys()), (
+ f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
+ )
+
+
+def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
+ return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
+
+
+def is_image_key(k: str) -> bool:
+ return k.startswith(OBS_IMAGES)
+
+
+def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor:
+ assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}"
+ # (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution
+ image = image.permute(2, 0, 1)
+ dims = (resize_dims[1], resize_dims[2])
+ # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
+ image_batched = image.unsqueeze(0)
+ # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
+ resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
+
+ return resized.squeeze(0)
+
+
+# TODO(Steven): Consider implementing a pipeline step for this
+def raw_observation_to_observation(
+ raw_observation: RawObservation,
+ lerobot_features: dict[str, dict],
+ policy_image_features: dict[str, PolicyFeature],
+) -> Observation:
+ observation = {}
+
+ observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
+ for k, v in observation.items():
+ if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
+ if "image" in k:
+ # Policy expects images in shape (B, C, H, W)
+ observation[k] = prepare_image(v).unsqueeze(0)
+ else:
+ observation[k] = v
+
+ return observation
+
+
+def prepare_image(image: torch.Tensor) -> torch.Tensor:
+ """Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
+ image = image.type(torch.float32) / 255
+ image = image.contiguous()
+
+ return image
+
+
+def extract_state_from_raw_observation(
+ lerobot_obs: RawObservation,
+) -> torch.Tensor:
+ """Extract the state from a raw observation."""
+ state = torch.tensor(lerobot_obs[OBS_STATE])
+
+ if state.ndim == 1:
+ state = state.unsqueeze(0)
+
+ return state
+
+
+def extract_images_from_raw_observation(
+ lerobot_obs: RawObservation,
+ camera_key: str,
+) -> dict[str, torch.Tensor]:
+ """Extract the images from a raw observation."""
+ return torch.tensor(lerobot_obs[camera_key])
+
+
+def make_lerobot_observation(
+ robot_obs: RawObservation,
+ lerobot_features: dict[str, dict],
+) -> LeRobotObservation:
+ """Make a lerobot observation from a raw observation."""
+ return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
+
+
+def prepare_raw_observation(
+ robot_obs: RawObservation,
+ lerobot_features: dict[str, dict],
+ policy_image_features: dict[str, PolicyFeature],
+) -> Observation:
+ """Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
+ policy_image_features)."""
+ # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
+ # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
+ lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
+
+ # 2. Greps all observation.images.<> keys
+ image_keys = list(filter(is_image_key, lerobot_obs))
+ # state's shape is expected as (B, state_dim)
+ state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
+ image_dict = {
+ image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys
+ }
+
+ # Turns the image features to (C, H, W) with H, W matching the policy image features.
+ # This reduces the resolution of the images
+ image_dict = {
+ key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
+ for key in image_keys
+ }
+
+ if "task" in robot_obs:
+ state_dict["task"] = robot_obs["task"]
+
+ return {**state_dict, **image_dict}
+
+
+def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
+ """
+ Get a logger using the standardized logging setup from utils.py.
+
+ Args:
+ name: Logger name (e.g., 'policy_server', 'robot_client')
+ log_to_file: Whether to also log to a file
+
+ Returns:
+ Configured logger instance
+ """
+ # Create logs directory if logging to file
+ if log_to_file:
+ os.makedirs("logs", exist_ok=True)
+ log_file = Path(f"logs/{name}_{int(time.time())}.log")
+ else:
+ log_file = None
+
+ # Initialize the standardized logging
+ init_logging(log_file=log_file, display_pid=False)
+
+ # Return a named logger
+ return logging.getLogger(name)
+
+
+@dataclass
+class TimedData:
+ """A data object with timestamp and timestep information.
+
+ Args:
+ timestamp: Unix timestamp relative to data's creation.
+ data: The actual data to wrap a timestamp around.
+ timestep: The timestep of the data.
+ """
+
+ timestamp: float
+ timestep: int
+
+ def get_timestamp(self):
+ return self.timestamp
+
+ def get_timestep(self):
+ return self.timestep
+
+
+@dataclass
+class TimedAction(TimedData):
+ action: Action
+
+ def get_action(self):
+ return self.action
+
+
+@dataclass
+class TimedObservation(TimedData):
+ observation: RawObservation
+ must_go: bool = False
+
+ def get_observation(self):
+ return self.observation
+
+
+@dataclass
+class FPSTracker:
+ """Utility class to track FPS metrics over time."""
+
+ target_fps: float
+ first_timestamp: float = None
+ total_obs_count: int = 0
+
+ def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
+ """Calculate average FPS vs target"""
+ self.total_obs_count += 1
+
+ # Initialize first observation time
+ if self.first_timestamp is None:
+ self.first_timestamp = current_timestamp
+
+ # Calculate overall average FPS (since start)
+ total_duration = current_timestamp - self.first_timestamp
+ avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
+
+ return {"avg_fps": avg_fps, "target_fps": self.target_fps}
+
+ def reset(self):
+ """Reset the FPS tracker state"""
+ self.first_timestamp = None
+ self.total_obs_count = 0
+
+
+@dataclass
+class RemotePolicyConfig:
+ policy_type: str
+ pretrained_name_or_path: str
+ lerobot_features: dict[str, PolicyFeature]
+ actions_per_chunk: int
+ device: str = "cpu"
+
+
+def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
+ """Check if two observation states are similar, under a tolerance threshold"""
+ return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
+
+
+def observations_similar(
+ obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
+) -> bool:
+ """Check if two observations are similar, under a tolerance threshold. Measures distance between
+ observations as the difference in joint-space between the two observations.
+
+ NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
+ An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
+ to surpass this joint-space similarity check.
+ """
+ obs1_state = extract_state_from_raw_observation(
+ make_lerobot_observation(obs1.get_observation(), lerobot_features)
+ )
+ obs2_state = extract_state_from_raw_observation(
+ make_lerobot_observation(obs2.get_observation(), lerobot_features)
+ )
+
+ return _compare_observation_states(obs1_state, obs2_state, atol=atol)
diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py
new file mode 100644
index 000000000..f7e00dea4
--- /dev/null
+++ b/src/lerobot/async_inference/policy_server.py
@@ -0,0 +1,436 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example:
+```shell
+python -m lerobot.async_inference.policy_server \
+ --host=127.0.0.1 \
+ --port=8080 \
+ --fps=30 \
+ --inference_latency=0.033 \
+ --obs_queue_timeout=1
+```
+"""
+
+import logging
+import pickle # nosec
+import threading
+import time
+from concurrent import futures
+from dataclasses import asdict
+from pprint import pformat
+from queue import Empty, Queue
+from typing import Any
+
+import draccus
+import grpc
+import torch
+
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.processor import (
+ PolicyAction,
+ PolicyProcessorPipeline,
+)
+from lerobot.transport import (
+ services_pb2, # type: ignore
+ services_pb2_grpc, # type: ignore
+)
+from lerobot.transport.utils import receive_bytes_in_chunks
+
+from .configs import PolicyServerConfig
+from .constants import SUPPORTED_POLICIES
+from .helpers import (
+ FPSTracker,
+ Observation,
+ RemotePolicyConfig,
+ TimedAction,
+ TimedObservation,
+ get_logger,
+ observations_similar,
+ raw_observation_to_observation,
+)
+
+
+class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
+ prefix = "policy_server"
+ logger = get_logger(prefix)
+
+ def __init__(self, config: PolicyServerConfig):
+ self.config = config
+ self.shutdown_event = threading.Event()
+
+ # FPS measurement
+ self.fps_tracker = FPSTracker(target_fps=config.fps)
+
+ self.observation_queue = Queue(maxsize=1)
+
+ self._predicted_timesteps_lock = threading.Lock()
+ self._predicted_timesteps = set()
+
+ self.last_processed_obs = None
+
+ # Attributes will be set by SendPolicyInstructions
+ self.device = None
+ self.policy_type = None
+ self.lerobot_features = None
+ self.actions_per_chunk = None
+ self.policy = None
+ self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
+ self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None
+
+ @property
+ def running(self):
+ return not self.shutdown_event.is_set()
+
+ @property
+ def policy_image_features(self):
+ return self.policy.config.image_features
+
+ def _reset_server(self) -> None:
+ """Flushes server state when new client connects."""
+ # only running inference on the latest observation received by the server
+ self.shutdown_event.set()
+ self.observation_queue = Queue(maxsize=1)
+
+ with self._predicted_timesteps_lock:
+ self._predicted_timesteps = set()
+
+ def Ready(self, request, context): # noqa: N802
+ client_id = context.peer()
+ self.logger.info(f"Client {client_id} connected and ready")
+ self._reset_server()
+ self.shutdown_event.clear()
+
+ return services_pb2.Empty()
+
+ def SendPolicyInstructions(self, request, context): # noqa: N802
+ """Receive policy instructions from the robot client"""
+
+ if not self.running:
+ self.logger.warning("Server is not running. Ignoring policy instructions.")
+ return services_pb2.Empty()
+
+ client_id = context.peer()
+
+ policy_specs = pickle.loads(request.data) # nosec
+
+ if not isinstance(policy_specs, RemotePolicyConfig):
+ raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
+
+ if policy_specs.policy_type not in SUPPORTED_POLICIES:
+ raise ValueError(
+ f"Policy type {policy_specs.policy_type} not supported. "
+ f"Supported policies: {SUPPORTED_POLICIES}"
+ )
+
+ self.logger.info(
+ f"Receiving policy instructions from {client_id} | "
+ f"Policy type: {policy_specs.policy_type} | "
+ f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
+ f"Actions per chunk: {policy_specs.actions_per_chunk} | "
+ f"Device: {policy_specs.device}"
+ )
+
+ self.device = policy_specs.device
+ self.policy_type = policy_specs.policy_type # act, pi0, etc.
+ self.lerobot_features = policy_specs.lerobot_features
+ self.actions_per_chunk = policy_specs.actions_per_chunk
+
+ policy_class = get_policy_class(self.policy_type)
+
+ start = time.perf_counter()
+ self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
+ self.policy.to(self.device)
+
+ # Load preprocessor and postprocessor, overriding device to match requested device
+ device_override = {"device": self.device}
+ self.preprocessor, self.postprocessor = make_pre_post_processors(
+ self.policy.config,
+ pretrained_path=policy_specs.pretrained_name_or_path,
+ preprocessor_overrides={"device_processor": device_override},
+ postprocessor_overrides={"device_processor": device_override},
+ )
+
+ end = time.perf_counter()
+
+ self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
+
+ return services_pb2.Empty()
+
+ def SendObservations(self, request_iterator, context): # noqa: N802
+ """Receive observations from the robot client"""
+ client_id = context.peer()
+ self.logger.debug(f"Receiving observations from {client_id}")
+
+ receive_time = time.time() # comparing timestamps so need time.time()
+ start_deserialize = time.perf_counter()
+ received_bytes = receive_bytes_in_chunks(
+ request_iterator, None, self.shutdown_event, self.logger
+ ) # blocking call while looping over request_iterator
+ timed_observation = pickle.loads(received_bytes) # nosec
+ deserialize_time = time.perf_counter() - start_deserialize
+
+ self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
+
+ obs_timestep = timed_observation.get_timestep()
+ obs_timestamp = timed_observation.get_timestamp()
+
+ # Calculate FPS metrics
+ fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
+
+ self.logger.debug(
+ f"Received observation #{obs_timestep} | "
+ f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
+ f"Target: {fps_metrics['target_fps']:.2f} | "
+ f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
+ )
+
+ self.logger.debug(
+ f"Server timestamp: {receive_time:.6f} | "
+ f"Client timestamp: {obs_timestamp:.6f} | "
+ f"Deserialization time: {deserialize_time:.6f}s"
+ )
+
+ if not self._enqueue_observation(
+ timed_observation # wrapping a RawObservation
+ ):
+ self.logger.debug(f"Observation #{obs_timestep} has been filtered out")
+
+ return services_pb2.Empty()
+
+ def GetActions(self, request, context): # noqa: N802
+ """Returns actions to the robot client. Actions are sent as a single
+ chunk, containing multiple actions."""
+ client_id = context.peer()
+ self.logger.debug(f"Client {client_id} connected for action streaming")
+
+ # Generate action based on the most recent observation and its timestep
+ try:
+ getactions_starts = time.perf_counter()
+ obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
+ self.logger.info(
+ f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
+ )
+
+ with self._predicted_timesteps_lock:
+ self._predicted_timesteps.add(obs.get_timestep())
+
+ start_time = time.perf_counter()
+ action_chunk = self._predict_action_chunk(obs)
+ inference_time = time.perf_counter() - start_time
+
+ start_time = time.perf_counter()
+ actions_bytes = pickle.dumps(action_chunk) # nosec
+ serialize_time = time.perf_counter() - start_time
+
+ # Create and return the action chunk
+ actions = services_pb2.Actions(data=actions_bytes)
+
+ self.logger.info(
+ f"Action chunk #{obs.get_timestep()} generated | "
+ f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
+ )
+
+ self.logger.debug(
+ f"Action chunk #{obs.get_timestep()} generated | "
+ f"Inference time: {inference_time:.2f}s |"
+ f"Serialize time: {serialize_time:.2f}s |"
+ f"Total time: {inference_time + serialize_time:.2f}s"
+ )
+
+ time.sleep(
+ max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
+ ) # sleep controls inference latency
+
+ return actions
+
+ except Empty: # no observation added to queue in obs_queue_timeout
+ return services_pb2.Empty()
+
+ except Exception as e:
+ self.logger.error(f"Error in StreamActions: {e}")
+
+ return services_pb2.Empty()
+
+ def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
+ """Check if the observation is valid to be processed by the policy"""
+ with self._predicted_timesteps_lock:
+ predicted_timesteps = self._predicted_timesteps
+
+ if obs.get_timestep() in predicted_timesteps:
+ self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
+ return False
+
+ elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
+ self.logger.debug(
+ f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
+ )
+ return False
+
+ else:
+ return True
+
+ def _enqueue_observation(self, obs: TimedObservation) -> bool:
+ """Enqueue an observation if it must go through processing, otherwise skip it.
+ Observations not in queue are never run through the policy network"""
+
+ if (
+ obs.must_go
+ or self.last_processed_obs is None
+ or self._obs_sanity_checks(obs, self.last_processed_obs)
+ ):
+ last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
+ self.logger.debug(
+ f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
+ )
+
+ # If queue is full, get the old observation to make room
+ if self.observation_queue.full():
+ # pops from queue
+ _ = self.observation_queue.get_nowait()
+ self.logger.debug("Observation queue was full, removed oldest observation")
+
+ # Now put the new observation (never blocks as queue is non-full here)
+ self.observation_queue.put(obs)
+ return True
+
+ return False
+
+ def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
+ """Turn a chunk of actions into a list of TimedAction instances,
+ with the first action corresponding to t_0 and the rest corresponding to
+ t_0 + i*environment_dt for i in range(len(action_chunk))
+ """
+ return [
+ TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
+ for i, action in enumerate(action_chunk)
+ ]
+
+ def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Get an action chunk from the policy. The chunk contains only"""
+ chunk = self.policy.predict_action_chunk(observation)
+ if chunk.ndim != 3:
+ chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
+
+ return chunk[:, : self.actions_per_chunk, :]
+
+ def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
+ """Predict an action chunk based on an observation.
+
+ Pipeline:
+ 1. Convert raw observation to LeRobot format
+ 2. Apply preprocessor (tokenization, normalization, batching, device placement)
+ 3. Run policy inference to get action chunk
+ 4. Apply postprocessor (unnormalization, device movement)
+ 5. Convert to TimedAction list
+ """
+ """1. Prepare observation"""
+ start_prepare = time.perf_counter()
+ observation: Observation = raw_observation_to_observation(
+ observation_t.get_observation(),
+ self.lerobot_features,
+ self.policy_image_features,
+ )
+ prepare_time = time.perf_counter() - start_prepare
+
+ """2. Apply preprocessor"""
+ start_preprocess = time.perf_counter()
+ observation = self.preprocessor(observation)
+ self.last_processed_obs: TimedObservation = observation_t
+ preprocessing_time = time.perf_counter() - start_preprocess
+
+ """3. Get action chunk"""
+ start_inference = time.perf_counter()
+ action_tensor = self._get_action_chunk(observation)
+ inference_time = time.perf_counter() - start_inference
+ self.logger.info(
+ f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
+ )
+
+ """4. Apply postprocessor"""
+ # Apply postprocessor (handles unnormalization and device movement)
+ # Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
+ # So we process each action in the chunk individually
+ start_postprocess = time.perf_counter()
+ _, chunk_size, _ = action_tensor.shape
+
+ # Process each action in the chunk
+ processed_actions = []
+ for i in range(chunk_size):
+ # Extract action at timestep i: (B, action_dim)
+ single_action = action_tensor[:, i, :]
+ processed_action = self.postprocessor(single_action)
+ processed_actions.append(processed_action)
+
+ # Stack back to (B, chunk_size, action_dim), then remove batch dim
+ action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
+ self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
+
+ """5. Convert to TimedAction list"""
+ action_chunk = self._time_action_chunk(
+ observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
+ )
+ postprocess_stops = time.perf_counter()
+ postprocessing_time = postprocess_stops - start_postprocess
+
+ self.logger.info(
+ f"Observation {observation_t.get_timestep()} | "
+ f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
+ )
+
+ self.logger.debug(
+ f"Observation {observation_t.get_timestep()} | "
+ f"Prepare time: {1000 * prepare_time:.2f}ms | "
+ f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
+ f"Inference time: {1000 * inference_time:.2f}ms | "
+ f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
+ f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
+ )
+
+ return action_chunk
+
+ def stop(self):
+ """Stop the server"""
+ self._reset_server()
+ self.logger.info("Server stopping...")
+
+
+@draccus.wrap()
+def serve(cfg: PolicyServerConfig):
+ """Start the PolicyServer with the given configuration.
+
+ Args:
+ config: PolicyServerConfig instance. If None, uses default configuration.
+ """
+ logging.info(pformat(asdict(cfg)))
+
+ # Create the server instance first
+ policy_server = PolicyServer(cfg)
+
+ # Setup and start gRPC server
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
+ services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
+ server.add_insecure_port(f"{cfg.host}:{cfg.port}")
+
+ policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
+ server.start()
+
+ server.wait_for_termination()
+
+ policy_server.logger.info("Server terminated")
+
+
+if __name__ == "__main__":
+ serve()
diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py
new file mode 100644
index 000000000..8c4425c6b
--- /dev/null
+++ b/src/lerobot/async_inference/robot_client.py
@@ -0,0 +1,509 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example command:
+```shell
+python src/lerobot/async_inference/robot_client.py \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58760431541 \
+ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
+ --robot.id=black \
+ --task="dummy" \
+ --server_address=127.0.0.1:8080 \
+ --policy_type=act \
+ --pretrained_name_or_path=user/model \
+ --policy_device=mps \
+ --actions_per_chunk=50 \
+ --chunk_size_threshold=0.5 \
+ --aggregate_fn_name=weighted_average \
+ --debug_visualize_queue_size=True
+```
+"""
+
+import logging
+import pickle # nosec
+import threading
+import time
+from collections.abc import Callable
+from dataclasses import asdict
+from pprint import pformat
+from queue import Queue
+from typing import Any
+
+import draccus
+import grpc
+import torch
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.robots import ( # noqa: F401
+ Robot,
+ RobotConfig,
+ bi_so100_follower,
+ koch_follower,
+ make_robot_from_config,
+ so100_follower,
+ so101_follower,
+)
+from lerobot.transport import (
+ services_pb2, # type: ignore
+ services_pb2_grpc, # type: ignore
+)
+from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
+
+from .configs import RobotClientConfig
+from .constants import SUPPORTED_ROBOTS
+from .helpers import (
+ Action,
+ FPSTracker,
+ Observation,
+ RawObservation,
+ RemotePolicyConfig,
+ TimedAction,
+ TimedObservation,
+ get_logger,
+ map_robot_keys_to_lerobot_features,
+ validate_robot_cameras_for_policy,
+ visualize_action_queue_size,
+)
+
+
+class RobotClient:
+ prefix = "robot_client"
+ logger = get_logger(prefix)
+
+ def __init__(self, config: RobotClientConfig):
+ """Initialize RobotClient with unified configuration.
+
+ Args:
+ config: RobotClientConfig containing all configuration parameters
+ """
+ # Store configuration
+ self.config = config
+ self.robot = make_robot_from_config(config.robot)
+ self.robot.connect()
+
+ lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
+
+ if config.verify_robot_cameras:
+ # Load policy config for validation
+ policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
+ policy_image_features = policy_config.image_features
+
+ # The cameras specified for inference must match the one supported by the policy chosen
+ validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
+
+ # Use environment variable if server_address is not provided in config
+ self.server_address = config.server_address
+
+ self.policy_config = RemotePolicyConfig(
+ config.policy_type,
+ config.pretrained_name_or_path,
+ lerobot_features,
+ config.actions_per_chunk,
+ config.policy_device,
+ )
+ self.channel = grpc.insecure_channel(
+ self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
+ )
+ self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
+ self.logger.info(f"Initializing client to connect to server at {self.server_address}")
+
+ self.shutdown_event = threading.Event()
+
+ # Initialize client side variables
+ self.latest_action_lock = threading.Lock()
+ self.latest_action = -1
+ self.action_chunk_size = -1
+
+ self._chunk_size_threshold = config.chunk_size_threshold
+
+ self.action_queue = Queue()
+ self.action_queue_lock = threading.Lock() # Protect queue operations
+ self.action_queue_size = []
+ self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
+
+ # FPS measurement
+ self.fps_tracker = FPSTracker(target_fps=self.config.fps)
+
+ self.logger.info("Robot connected and ready")
+
+ # Use an event for thread-safe coordination
+ self.must_go = threading.Event()
+ self.must_go.set() # Initially set - observations qualify for direct processing
+
+ @property
+ def running(self):
+ return not self.shutdown_event.is_set()
+
+ def start(self):
+ """Start the robot client and connect to the policy server"""
+ try:
+ # client-server handshake
+ start_time = time.perf_counter()
+ self.stub.Ready(services_pb2.Empty())
+ end_time = time.perf_counter()
+ self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
+
+ # send policy instructions
+ policy_config_bytes = pickle.dumps(self.policy_config)
+ policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
+
+ self.logger.info("Sending policy instructions to policy server")
+ self.logger.debug(
+ f"Policy type: {self.policy_config.policy_type} | "
+ f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
+ f"Device: {self.policy_config.device}"
+ )
+
+ self.stub.SendPolicyInstructions(policy_setup)
+
+ self.shutdown_event.clear()
+
+ return True
+
+ except grpc.RpcError as e:
+ self.logger.error(f"Failed to connect to policy server: {e}")
+ return False
+
+ def stop(self):
+ """Stop the robot client"""
+ self.shutdown_event.set()
+
+ self.robot.disconnect()
+ self.logger.debug("Robot disconnected")
+
+ self.channel.close()
+ self.logger.debug("Client stopped, channel closed")
+
+ def send_observation(
+ self,
+ obs: TimedObservation,
+ ) -> bool:
+ """Send observation to the policy server.
+ Returns True if the observation was sent successfully, False otherwise."""
+ if not self.running:
+ raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
+
+ if not isinstance(obs, TimedObservation):
+ raise ValueError("Input observation needs to be a TimedObservation!")
+
+ start_time = time.perf_counter()
+ observation_bytes = pickle.dumps(obs)
+ serialize_time = time.perf_counter() - start_time
+ self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
+
+ try:
+ observation_iterator = send_bytes_in_chunks(
+ observation_bytes,
+ services_pb2.Observation,
+ log_prefix="[CLIENT] Observation",
+ silent=True,
+ )
+ _ = self.stub.SendObservations(observation_iterator)
+ obs_timestep = obs.get_timestep()
+ self.logger.debug(f"Sent observation #{obs_timestep} | ")
+
+ return True
+
+ except grpc.RpcError as e:
+ self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
+ return False
+
+ def _inspect_action_queue(self):
+ with self.action_queue_lock:
+ queue_size = self.action_queue.qsize()
+ timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
+ self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
+ return queue_size, timestamps
+
+ def _aggregate_action_queues(
+ self,
+ incoming_actions: list[TimedAction],
+ aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+ ):
+ """Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
+ if aggregate_fn is None:
+ # default aggregate function: take the latest action
+ def aggregate_fn(x1, x2):
+ return x2
+
+ future_action_queue = Queue()
+ with self.action_queue_lock:
+ internal_queue = self.action_queue.queue
+
+ current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
+
+ for new_action in incoming_actions:
+ with self.latest_action_lock:
+ latest_action = self.latest_action
+
+ # New action is older than the latest action in the queue, skip it
+ if new_action.get_timestep() <= latest_action:
+ continue
+
+ # If the new action's timestep is not in the current action queue, add it directly
+ elif new_action.get_timestep() not in current_action_queue:
+ future_action_queue.put(new_action)
+ continue
+
+ # If the new action's timestep is in the current action queue, aggregate it
+ # TODO: There is probably a way to do this with broadcasting of the two action tensors
+ future_action_queue.put(
+ TimedAction(
+ timestamp=new_action.get_timestamp(),
+ timestep=new_action.get_timestep(),
+ action=aggregate_fn(
+ current_action_queue[new_action.get_timestep()], new_action.get_action()
+ ),
+ )
+ )
+
+ with self.action_queue_lock:
+ self.action_queue = future_action_queue
+
+ def receive_actions(self, verbose: bool = False):
+ """Receive actions from the policy server"""
+ # Wait at barrier for synchronized start
+ self.start_barrier.wait()
+ self.logger.info("Action receiving thread starting")
+
+ while self.running:
+ try:
+ # Use StreamActions to get a stream of actions from the server
+ actions_chunk = self.stub.GetActions(services_pb2.Empty())
+ if len(actions_chunk.data) == 0:
+ continue # received `Empty` from server, wait for next call
+
+ receive_time = time.time()
+
+ # Deserialize bytes back into list[TimedAction]
+ deserialize_start = time.perf_counter()
+ timed_actions = pickle.loads(actions_chunk.data) # nosec
+ deserialize_time = time.perf_counter() - deserialize_start
+
+ self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
+
+ # Calculate network latency if we have matching observations
+ if len(timed_actions) > 0 and verbose:
+ with self.latest_action_lock:
+ latest_action = self.latest_action
+
+ self.logger.debug(f"Current latest action: {latest_action}")
+
+ # Get queue state before changes
+ old_size, old_timesteps = self._inspect_action_queue()
+ if not old_timesteps:
+ old_timesteps = [latest_action] # queue was empty
+
+ # Log incoming actions
+ incoming_timesteps = [a.get_timestep() for a in timed_actions]
+
+ first_action_timestep = timed_actions[0].get_timestep()
+ server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
+
+ self.logger.info(
+ f"Received action chunk for step #{first_action_timestep} | "
+ f"Latest action: #{latest_action} | "
+ f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
+ f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
+ f"Deserialization time: {deserialize_time * 1000:.2f}ms"
+ )
+
+ # Update action queue
+ start_time = time.perf_counter()
+ self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
+ queue_update_time = time.perf_counter() - start_time
+
+ self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
+
+ if verbose:
+ # Get queue state after changes
+ new_size, new_timesteps = self._inspect_action_queue()
+
+ with self.latest_action_lock:
+ latest_action = self.latest_action
+
+ self.logger.info(
+ f"Latest action: {latest_action} | "
+ f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
+ f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
+ f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
+ )
+ self.logger.debug(
+ f"Queue update complete ({queue_update_time:.6f}s) | "
+ f"Before: {old_size} items | "
+ f"After: {new_size} items | "
+ )
+
+ except grpc.RpcError as e:
+ self.logger.error(f"Error receiving actions: {e}")
+
+ def actions_available(self):
+ """Check if there are actions available in the queue"""
+ with self.action_queue_lock:
+ return not self.action_queue.empty()
+
+ def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
+ action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
+ return action
+
+ def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
+ """Reading and performing actions in local queue"""
+
+ # Lock only for queue operations
+ get_start = time.perf_counter()
+ with self.action_queue_lock:
+ self.action_queue_size.append(self.action_queue.qsize())
+ # Get action from queue
+ timed_action = self.action_queue.get_nowait()
+ get_end = time.perf_counter() - get_start
+
+ _performed_action = self.robot.send_action(
+ self._action_tensor_to_action_dict(timed_action.get_action())
+ )
+ with self.latest_action_lock:
+ self.latest_action = timed_action.get_timestep()
+
+ if verbose:
+ with self.action_queue_lock:
+ current_queue_size = self.action_queue.qsize()
+
+ self.logger.debug(
+ f"Ts={timed_action.get_timestamp()} | "
+ f"Action #{timed_action.get_timestep()} performed | "
+ f"Queue size: {current_queue_size}"
+ )
+
+ self.logger.debug(
+ f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
+ )
+
+ return _performed_action
+
+ def _ready_to_send_observation(self):
+ """Flags when the client is ready to send an observation"""
+ with self.action_queue_lock:
+ return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
+
+ def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
+ try:
+ # Get serialized observation bytes from the function
+ start_time = time.perf_counter()
+
+ raw_observation: RawObservation = self.robot.get_observation()
+ raw_observation["task"] = task
+
+ with self.latest_action_lock:
+ latest_action = self.latest_action
+
+ observation = TimedObservation(
+ timestamp=time.time(), # need time.time() to compare timestamps across client and server
+ observation=raw_observation,
+ timestep=max(latest_action, 0),
+ )
+
+ obs_capture_time = time.perf_counter() - start_time
+
+ # If there are no actions left in the queue, the observation must go through processing!
+ with self.action_queue_lock:
+ observation.must_go = self.must_go.is_set() and self.action_queue.empty()
+ current_queue_size = self.action_queue.qsize()
+
+ _ = self.send_observation(observation)
+
+ self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
+ if observation.must_go:
+ # must-go event will be set again after receiving actions
+ self.must_go.clear()
+
+ if verbose:
+ # Calculate comprehensive FPS metrics
+ fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
+
+ self.logger.info(
+ f"Obs #{observation.get_timestep()} | "
+ f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
+ f"Target: {fps_metrics['target_fps']:.2f}"
+ )
+
+ self.logger.debug(
+ f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
+ )
+
+ return raw_observation
+
+ except Exception as e:
+ self.logger.error(f"Error in observation sender: {e}")
+
+ def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
+ """Combined function for executing actions and streaming observations"""
+ # Wait at barrier for synchronized start
+ self.start_barrier.wait()
+ self.logger.info("Control loop thread starting")
+
+ _performed_action = None
+ _captured_observation = None
+
+ while self.running:
+ control_loop_start = time.perf_counter()
+ """Control loop: (1) Performing actions, when available"""
+ if self.actions_available():
+ _performed_action = self.control_loop_action(verbose)
+
+ """Control loop: (2) Streaming observations to the remote policy server"""
+ if self._ready_to_send_observation():
+ _captured_observation = self.control_loop_observation(task, verbose)
+
+ self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
+ # Dynamically adjust sleep time to maintain the desired control frequency
+ time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
+
+ return _captured_observation, _performed_action
+
+
+@draccus.wrap()
+def async_client(cfg: RobotClientConfig):
+ logging.info(pformat(asdict(cfg)))
+
+ if cfg.robot.type not in SUPPORTED_ROBOTS:
+ raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
+
+ client = RobotClient(cfg)
+
+ if client.start():
+ client.logger.info("Starting action receiver thread...")
+
+ # Create and start action receiver thread
+ action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
+
+ # Start action receiver thread
+ action_receiver_thread.start()
+
+ try:
+ # The main thread runs the control loop
+ client.control_loop(task=cfg.task)
+
+ finally:
+ client.stop()
+ action_receiver_thread.join()
+ if cfg.debug_visualize_queue_size:
+ visualize_action_queue_size(client.action_queue_size)
+ client.logger.info("Client stopped")
+
+
+if __name__ == "__main__":
+ async_client() # run the client
diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py
index 1937205b1..e435c7309 100644
--- a/src/lerobot/cameras/camera.py
+++ b/src/lerobot/cameras/camera.py
@@ -15,7 +15,7 @@
# limitations under the License.
import abc
-from typing import Any, Dict, List
+from typing import Any
import numpy as np
@@ -69,7 +69,7 @@ class Camera(abc.ABC):
@staticmethod
@abc.abstractmethod
- def find_cameras() -> List[Dict[str, Any]]:
+ def find_cameras() -> list[dict[str, Any]]:
"""Detects available cameras connected to the system.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py
index fd99922a4..50e55f0c2 100644
--- a/src/lerobot/cameras/opencv/camera_opencv.py
+++ b/src/lerobot/cameras/opencv/camera_opencv.py
@@ -18,16 +18,20 @@ Provides the OpenCVCamera class for capturing frames from cameras using OpenCV.
import logging
import math
+import os
import platform
import time
from pathlib import Path
from threading import Event, Lock, Thread
-from typing import Any, Dict, List
+from typing import Any
+# Fix MSMF hardware transform compatibility for Windows before importing cv2
+if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
+ os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2
import numpy as np
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
@@ -56,7 +60,7 @@ class OpenCVCamera(Camera):
or port changes, especially on Linux. Use the provided utility script to find
available camera indices or paths:
```bash
- python -m lerobot.find_cameras opencv
+ lerobot-find-cameras opencv
```
The camera's default settings (FPS, resolution, color mode) are used unless
@@ -161,8 +165,7 @@ class OpenCVCamera(Camera):
self.videocapture.release()
self.videocapture = None
raise ConnectionError(
- f"Failed to open {self}."
- f"Run `python -m lerobot.find_cameras opencv` to find available cameras."
+ f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras."
)
self._configure_capture_settings()
@@ -241,7 +244,7 @@ class OpenCVCamera(Camera):
)
@staticmethod
- def find_cameras() -> List[Dict[str, Any]]:
+ def find_cameras() -> list[dict[str, Any]]:
"""
Detects available OpenCV cameras connected to the system.
@@ -364,7 +367,7 @@ class OpenCVCamera(Camera):
if requested_color_mode == ColorMode.RGB:
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
+ if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
diff --git a/.github/workflows/trufflehog.yml b/src/lerobot/cameras/reachy2_camera/__init__.py
similarity index 57%
rename from .github/workflows/trufflehog.yml
rename to src/lerobot/cameras/reachy2_camera/__init__.py
index 704a3baaa..72e45f32a 100644
--- a/.github/workflows/trufflehog.yml
+++ b/src/lerobot/cameras/reachy2_camera/__init__.py
@@ -12,24 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-on:
- push:
-
-name: Secret Leaks
-
-permissions: {}
-
-jobs:
- trufflehog:
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- fetch-depth: 0
- persist-credentials: false
-
- - name: Secret Scanning
- uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
- with:
- extra_args: --only-verified
+from .configuration_reachy2_camera import Reachy2CameraConfig
+from .reachy2_camera import Reachy2Camera
diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py
new file mode 100644
index 000000000..5b2303ff2
--- /dev/null
+++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py
@@ -0,0 +1,78 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from ..configs import CameraConfig, ColorMode
+
+
+@CameraConfig.register_subclass("reachy2_camera")
+@dataclass
+class Reachy2CameraConfig(CameraConfig):
+ """Configuration class for Reachy 2 camera devices.
+
+ This class provides configuration options for Reachy 2 cameras,
+ supporting both the teleop and depth cameras. It includes settings
+ for resolution, frame rate, color mode, and the selection of the cameras.
+
+ Example configurations:
+ ```python
+ # Basic configurations
+ Reachy2CameraConfig(
+ name="teleop",
+ image_type="left",
+ ip_address="192.168.0.200", # IP address of the robot
+ fps=15,
+ width=640,
+ height=480,
+ color_mode=ColorMode.RGB,
+ ) # Left teleop camera, 640x480 @ 15FPS
+ ```
+
+ Attributes:
+ name: Name of the camera device. Can be "teleop" or "depth".
+ image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
+ For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
+ fps: Requested frames per second for the color stream.
+ width: Requested frame width in pixels for the color stream.
+ height: Requested frame height in pixels for the color stream.
+ color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
+ ip_address: IP address of the robot. Defaults to "localhost".
+ port: Port number for the camera server. Defaults to 50065.
+
+ Note:
+ - Only 3-channel color output (RGB/BGR) is currently supported.
+ """
+
+ name: str
+ image_type: str
+ color_mode: ColorMode = ColorMode.RGB
+ ip_address: str | None = "localhost"
+ port: int = 50065
+ # use_depth: bool = False
+
+ def __post_init__(self):
+ if self.name not in ["teleop", "depth"]:
+ raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
+ if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
+ self.name == "depth" and self.image_type not in ["rgb", "depth"]
+ ):
+ raise ValueError(
+ f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
+ )
+
+ if self.color_mode not in ["rgb", "bgr"]:
+ raise ValueError(
+ f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
+ )
diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py
new file mode 100644
index 000000000..c96789f96
--- /dev/null
+++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py
@@ -0,0 +1,288 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
+"""
+
+import logging
+import os
+import platform
+import time
+from threading import Event, Lock, Thread
+from typing import Any
+
+# Fix MSMF hardware transform compatibility for Windows before importing cv2
+if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
+ os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
+import cv2
+import numpy as np
+from reachy2_sdk.media.camera import CameraView
+from reachy2_sdk.media.camera_manager import CameraManager
+
+from lerobot.utils.errors import DeviceNotConnectedError
+
+from ..camera import Camera
+from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Reachy2Camera(Camera):
+ """
+ Manages Reachy 2 camera using Reachy 2 CameraManager.
+
+ This class provides a high-level interface to connect to, configure, and read
+ frames from Reachy 2 cameras. It supports both synchronous and asynchronous
+ frame reading.
+
+ An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
+ type (e.g., "left") to be specified in the configuration.
+
+ The camera's default settings (FPS, resolution, color mode) are used unless
+ overridden in the configuration.
+ """
+
+ def __init__(self, config: Reachy2CameraConfig):
+ """
+ Initializes the Reachy2Camera instance.
+
+ Args:
+ config: The configuration settings for the camera.
+ """
+ super().__init__(config)
+
+ self.config = config
+
+ self.fps = config.fps
+ self.color_mode = config.color_mode
+
+ self.cam_manager: CameraManager | None = None
+
+ self.thread: Thread | None = None
+ self.stop_event: Event | None = None
+ self.frame_lock: Lock = Lock()
+ self.latest_frame: np.ndarray | None = None
+ self.new_frame_event: Event = Event()
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
+
+ @property
+ def is_connected(self) -> bool:
+ """Checks if the camera is currently connected and opened."""
+ if self.config.name == "teleop":
+ return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
+ elif self.config.name == "depth":
+ return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
+ else:
+ raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
+
+ def connect(self, warmup: bool = True):
+ """
+ Connects to the Reachy2 CameraManager as specified in the configuration.
+ """
+ self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
+ self.cam_manager.initialize_cameras()
+
+ logger.info(f"{self} connected.")
+
+ @staticmethod
+ def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
+ """
+ Detects available Reachy 2 cameras.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries,
+ where each dictionary contains 'name', 'stereo',
+ and the default profile properties (width, height, fps).
+ """
+ initialized_cameras = []
+ camera_manager = CameraManager(host=ip_address, port=port)
+
+ for camera in [camera_manager.teleop, camera_manager.depth]:
+ if camera is None:
+ continue
+
+ height, width, _, _, _, _, _ = camera.get_parameters()
+
+ camera_info = {
+ "name": camera._cam_info.name,
+ "stereo": camera._cam_info.stereo,
+ "default_profile": {
+ "width": width,
+ "height": height,
+ "fps": 30,
+ },
+ }
+ initialized_cameras.append(camera_info)
+
+ camera_manager.disconnect()
+ return initialized_cameras
+
+ def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
+ """
+ Reads a single frame synchronously from the camera.
+
+ This is a blocking call.
+
+ Args:
+ color_mode (Optional[ColorMode]): If specified, overrides the default
+ color mode (`self.color_mode`) for this read operation (e.g.,
+ request RGB even if default is BGR).
+
+ Returns:
+ np.ndarray: The captured frame as a NumPy array in the format
+ (height, width, channels), using the specified or default
+ color mode and applying any configured rotation.
+ """
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ start_time = time.perf_counter()
+
+ frame = None
+
+ if self.cam_manager is None:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+ else:
+ if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
+ if self.config.image_type == "left":
+ frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
+ elif self.config.image_type == "right":
+ frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
+ elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
+ if self.config.image_type == "depth":
+ frame = self.cam_manager.depth.get_depth_frame()[0]
+ elif self.config.image_type == "rgb":
+ frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
+
+ if frame is None:
+ return np.empty((0, 0, 3), dtype=np.uint8)
+
+ if self.config.color_mode == "rgb":
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+ read_duration_ms = (time.perf_counter() - start_time) * 1e3
+ logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
+
+ return frame
+
+ def _read_loop(self):
+ """
+ Internal loop run by the background thread for asynchronous reading.
+
+ On each iteration:
+ 1. Reads a color frame
+ 2. Stores result in latest_frame (thread-safe)
+ 3. Sets new_frame_event to notify listeners
+
+ Stops on DeviceNotConnectedError, logs other errors and continues.
+ """
+ while not self.stop_event.is_set():
+ try:
+ color_image = self.read()
+
+ with self.frame_lock:
+ self.latest_frame = color_image
+ self.new_frame_event.set()
+
+ except DeviceNotConnectedError:
+ break
+ except Exception as e:
+ logger.warning(f"Error reading frame in background thread for {self}: {e}")
+
+ def _start_read_thread(self) -> None:
+ """Starts or restarts the background read thread if it's not running."""
+ if self.thread is not None and self.thread.is_alive():
+ self.thread.join(timeout=0.1)
+ if self.stop_event is not None:
+ self.stop_event.set()
+
+ self.stop_event = Event()
+ self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
+ self.thread.daemon = True
+ self.thread.start()
+
+ def _stop_read_thread(self) -> None:
+ """Signals the background read thread to stop and waits for it to join."""
+ if self.stop_event is not None:
+ self.stop_event.set()
+
+ if self.thread is not None and self.thread.is_alive():
+ self.thread.join(timeout=2.0)
+
+ self.thread = None
+ self.stop_event = None
+
+ def async_read(self, timeout_ms: float = 200) -> np.ndarray:
+ """
+ Reads the latest available frame asynchronously.
+
+ This method retrieves the most recent frame captured by the background
+ read thread. It does not block waiting for the camera hardware directly,
+ but may wait up to timeout_ms for the background thread to provide a frame.
+
+ Args:
+ timeout_ms (float): Maximum time in milliseconds to wait for a frame
+ to become available. Defaults to 200ms (0.2 seconds).
+
+ Returns:
+ np.ndarray: The latest captured frame as a NumPy array in the format
+ (height, width, channels), processed according to configuration.
+
+ Raises:
+ DeviceNotConnectedError: If the camera is not connected.
+ TimeoutError: If no frame becomes available within the specified timeout.
+ RuntimeError: If an unexpected error occurs.
+ """
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ if self.thread is None or not self.thread.is_alive():
+ self._start_read_thread()
+
+ if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
+ thread_alive = self.thread is not None and self.thread.is_alive()
+ raise TimeoutError(
+ f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
+ f"Read thread alive: {thread_alive}."
+ )
+
+ with self.frame_lock:
+ frame = self.latest_frame
+ self.new_frame_event.clear()
+
+ if frame is None:
+ raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
+
+ return frame
+
+ def disconnect(self):
+ """
+ Stops the background read thread (if running).
+
+ Raises:
+ DeviceNotConnectedError: If the camera is already disconnected.
+ """
+ if not self.is_connected and self.thread is None:
+ raise DeviceNotConnectedError(f"{self} not connected.")
+
+ if self.thread is not None:
+ self._stop_read_thread()
+
+ if self.cam_manager is not None:
+ self.cam_manager.disconnect()
+
+ logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py
index 96531b694..cc816e552 100644
--- a/src/lerobot/cameras/realsense/camera_realsense.py
+++ b/src/lerobot/cameras/realsense/camera_realsense.py
@@ -19,7 +19,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
import logging
import time
from threading import Event, Lock, Thread
-from typing import Any, Dict, List
+from typing import Any
import cv2
import numpy as np
@@ -29,7 +29,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -51,7 +51,7 @@ class RealSenseCamera(Camera):
Use the provided utility script to find available camera indices and default profiles:
```bash
- python -m lerobot.find_cameras realsense
+ lerobot-find-cameras realsense
```
A `RealSenseCamera` instance requires a configuration object specifying the
@@ -176,8 +176,7 @@ class RealSenseCamera(Camera):
self.rs_profile = None
self.rs_pipeline = None
raise ConnectionError(
- f"Failed to open {self}."
- "Run `python -m lerobot.find_cameras realsense` to find available cameras."
+ f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras."
) from e
self._configure_capture_settings()
@@ -194,7 +193,7 @@ class RealSenseCamera(Camera):
logger.info(f"{self} connected.")
@staticmethod
- def find_cameras() -> List[Dict[str, Any]]:
+ def find_cameras() -> list[dict[str, Any]]:
"""
Detects available Intel RealSense cameras connected to the system.
@@ -434,7 +433,7 @@ class RealSenseCamera(Camera):
if self.color_mode == ColorMode.BGR:
processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
+ if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
processed_image = cv2.rotate(processed_image, self.rotation)
return processed_image
diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py
index 82e7c0d36..36a86876d 100644
--- a/src/lerobot/cameras/realsense/configuration_realsense.py
+++ b/src/lerobot/cameras/realsense/configuration_realsense.py
@@ -28,12 +28,12 @@ class RealSenseCameraConfig(CameraConfig):
Example configurations for Intel RealSense D405:
```python
# Basic configurations
- RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS
- RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS
+ RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS
+ RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations
RealSenseCameraConfig("0123456789", 30, 640, 480, use_depth=True) # With depth sensing
- RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
+ RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
```
Attributes:
diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py
index f8bbd6e70..aa6ff98b4 100644
--- a/src/lerobot/cameras/utils.py
+++ b/src/lerobot/cameras/utils.py
@@ -15,19 +15,19 @@
# limitations under the License.
import platform
-from pathlib import Path
-from typing import TypeAlias
+from typing import cast
+
+from lerobot.utils.import_utils import make_device_from_device_class
from .camera import Camera
from .configs import CameraConfig, Cv2Rotation
-IndexOrPath: TypeAlias = int | Path
-
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
- cameras = {}
+ cameras: dict[str, Camera] = {}
for key, cfg in camera_configs.items():
+ # TODO(Steven): Consider just using the make_device_from_device_class for all types
if cfg.type == "opencv":
from .opencv import OpenCVCamera
@@ -37,8 +37,17 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
from .realsense.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
+
+ elif cfg.type == "reachy2_camera":
+ from .reachy2_camera.reachy2_camera import Reachy2Camera
+
+ cameras[key] = Reachy2Camera(cfg)
+
else:
- raise ValueError(f"The motor type '{cfg.type}' is not valid.")
+ try:
+ cameras[key] = cast(Camera, make_device_from_device_class(cfg))
+ except Exception as e:
+ raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e
return cameras
@@ -60,6 +69,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
- return cv2.CAP_AVFOUNDATION
- else:
+ return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
+ # elif platform.system() == "Darwin": # macOS
+ # return cv2.CAP_AVFOUNDATION
+ else: # Linux and others
return cv2.CAP_ANY
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index 0e0bafdd9..fdd0d9ab4 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -16,9 +16,6 @@
from dataclasses import dataclass, field
-from lerobot import (
- policies, # noqa: F401
-)
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
@@ -37,6 +34,7 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
+ streaming: bool = False
@dataclass
diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py
index 1da7ad83f..2296eaa20 100644
--- a/src/lerobot/configs/parser.py
+++ b/src/lerobot/configs/parser.py
@@ -16,9 +16,9 @@ import inspect
import pkgutil
import sys
from argparse import ArgumentError
+from collections.abc import Sequence
from functools import wraps
from pathlib import Path
-from typing import Sequence
import draccus
@@ -76,9 +76,8 @@ def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
- Values are the corresponding argument values
Example:
- >>> args = ['--env.discover_packages_path=my_package',
- ... '--other_arg=value']
- >>> parse_plugin_args('discover_packages_path', args)
+ >>> args = ["--env.discover_packages_path=my_package", "--other_arg=value"]
+ >>> parse_plugin_args("discover_packages_path", args)
{'env.discover_packages_path': 'my_package'}
"""
plugin_args = {}
@@ -111,7 +110,7 @@ def load_plugin(plugin_path: str) -> None:
PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
Examples:
- >>> load_plugin("external_plugin.core") # Loads plugin from external package
+ >>> load_plugin("external_plugin.core") # Loads plugin from external package
Notes:
- The plugin package should handle its own registration during import
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index c4502c841..af18860c7 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -12,24 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import builtins
+import json
import logging
import os
+import tempfile
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Type, TypeVar
+from typing import TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
-# Generic variable that is either PreTrainedConfig or a subclass thereof
T = TypeVar("T", bound="PreTrainedConfig")
@@ -50,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
n_obs_steps: int = 1
- normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
@@ -72,9 +74,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
+ # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
+ # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
+ pretrained_path: str | None = None
def __post_init__(self):
- self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
@@ -126,8 +130,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def robot_state_feature(self) -> PolicyFeature | None:
- for _, ft in self.input_features.items():
- if ft.type is FeatureType.STATE:
+ for ft_name, ft in self.input_features.items():
+ if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
return ft
return None
@@ -144,8 +148,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@property
def action_feature(self) -> PolicyFeature | None:
- for _, ft in self.output_features.items():
- if ft.type is FeatureType.ACTION:
+ for ft_name, ft in self.output_features.items():
+ if ft.type is FeatureType.ACTION and ft_name == ACTION:
return ft
return None
@@ -155,7 +159,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
@classmethod
def from_pretrained(
- cls: Type[T],
+ cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
@@ -192,8 +196,21 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
- # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
+ # HACK: Parse the original config to get the config subclass, so that we can
+ # apply cli overrides.
+ # This is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
+ with draccus.config_type("json"):
+ orig_config = draccus.parse(cls, config_file, args=[])
+
+ with open(config_file) as f:
+ config = json.load(f)
+
+ config.pop("type")
+ with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
+ json.dump(config, f)
+ config_file = f.name
+
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
- return draccus.parse(cls, config_file, args=cli_overrides)
+ return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 23195619b..f56061e4e 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import builtins
import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Type
import draccus
from huggingface_hub import hf_hub_download
@@ -137,7 +137,7 @@ class TrainPipelineConfig(HubMixin):
@classmethod
def from_pretrained(
- cls: Type["TrainPipelineConfig"],
+ cls: builtins.type["TrainPipelineConfig"],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py
index 6040ff70b..cb578060e 100644
--- a/src/lerobot/configs/types.py
+++ b/src/lerobot/configs/types.py
@@ -15,7 +15,6 @@
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Protocol
class FeatureType(str, Enum):
@@ -24,16 +23,20 @@ class FeatureType(str, Enum):
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
+ LANGUAGE = "LANGUAGE"
+
+
+class PipelineFeatureType(str, Enum):
+ ACTION = "ACTION"
+ OBSERVATION = "OBSERVATION"
class NormalizationMode(str, Enum):
MIN_MAX = "MIN_MAX"
MEAN_STD = "MEAN_STD"
IDENTITY = "IDENTITY"
-
-
-class DictLike(Protocol):
- def __getitem__(self, key: Any) -> Any: ...
+ QUANTILES = "QUANTILES"
+ QUANTILE10 = "QUANTILE10"
@dataclass
diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py
new file mode 100644
index 000000000..803645f29
--- /dev/null
+++ b/src/lerobot/datasets/aggregate.py
@@ -0,0 +1,495 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import shutil
+from pathlib import Path
+
+import pandas as pd
+import tqdm
+
+from lerobot.datasets.compute_stats import aggregate_stats
+from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.datasets.utils import (
+ DEFAULT_CHUNK_SIZE,
+ DEFAULT_DATA_FILE_SIZE_IN_MB,
+ DEFAULT_DATA_PATH,
+ DEFAULT_EPISODES_PATH,
+ DEFAULT_VIDEO_FILE_SIZE_IN_MB,
+ DEFAULT_VIDEO_PATH,
+ get_parquet_file_size_in_mb,
+ get_video_size_in_mb,
+ to_parquet_with_hf_images,
+ update_chunk_file_indices,
+ write_info,
+ write_stats,
+ write_tasks,
+)
+from lerobot.datasets.video_utils import concatenate_video_files
+
+
+def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
+ """Validates that all dataset metadata have consistent properties.
+
+ Ensures all datasets have the same fps, robot_type, and features to guarantee
+ compatibility when aggregating them into a single dataset.
+
+ Args:
+ all_metadata: List of LeRobotDatasetMetadata objects to validate.
+
+ Returns:
+ tuple: A tuple containing (fps, robot_type, features) from the first metadata.
+
+ Raises:
+ ValueError: If any metadata has different fps, robot_type, or features
+ than the first metadata in the list.
+ """
+
+ fps = all_metadata[0].fps
+ robot_type = all_metadata[0].robot_type
+ features = all_metadata[0].features
+
+ for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
+ if fps != meta.fps:
+ raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
+ if robot_type != meta.robot_type:
+ raise ValueError(
+ f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
+ )
+ if features != meta.features:
+ raise ValueError(
+ f"Same features is expected, but got features={meta.features} instead of {features}."
+ )
+
+ return fps, robot_type, features
+
+
+def update_data_df(df, src_meta, dst_meta):
+ """Updates a data DataFrame with new indices and task mappings for aggregation.
+
+ Adjusts episode indices, frame indices, and task indices to account for
+ previously aggregated data in the destination dataset.
+
+ Args:
+ df: DataFrame containing the data to be updated.
+ src_meta: Source dataset metadata.
+ dst_meta: Destination dataset metadata.
+
+ Returns:
+ pd.DataFrame: Updated DataFrame with adjusted indices.
+ """
+
+ df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
+ df["index"] = df["index"] + dst_meta.info["total_frames"]
+
+ src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
+ df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
+
+ return df
+
+
+def update_meta_data(
+ df,
+ dst_meta,
+ meta_idx,
+ data_idx,
+ videos_idx,
+):
+ """Updates metadata DataFrame with new chunk, file, and timestamp indices.
+
+ Adjusts all indices and timestamps to account for previously aggregated
+ data and videos in the destination dataset.
+
+ Args:
+ df: DataFrame containing the metadata to be updated.
+ dst_meta: Destination dataset metadata.
+ meta_idx: Dictionary containing current metadata chunk and file indices.
+ data_idx: Dictionary containing current data chunk and file indices.
+ videos_idx: Dictionary containing current video indices and timestamps.
+
+ Returns:
+ pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
+ """
+
+ df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
+ df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
+ df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
+ df["data/file_index"] = df["data/file_index"] + data_idx["file"]
+ for key, video_idx in videos_idx.items():
+ df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
+ df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
+ df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
+ df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
+
+ df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
+ df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
+ df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
+
+ return df
+
+
+def aggregate_datasets(
+ repo_ids: list[str],
+ aggr_repo_id: str,
+ roots: list[Path] | None = None,
+ aggr_root: Path | None = None,
+ data_files_size_in_mb: float | None = None,
+ video_files_size_in_mb: float | None = None,
+ chunk_size: int | None = None,
+):
+ """Aggregates multiple LeRobot datasets into a single unified dataset.
+
+ This is the main function that orchestrates the aggregation process by:
+ 1. Loading and validating all source dataset metadata
+ 2. Creating a new destination dataset with unified tasks
+ 3. Aggregating videos, data, and metadata from all source datasets
+ 4. Finalizing the aggregated dataset with proper statistics
+
+ Args:
+ repo_ids: List of repository IDs for the datasets to aggregate.
+ aggr_repo_id: Repository ID for the aggregated output dataset.
+ roots: Optional list of root paths for the source datasets.
+ aggr_root: Optional root path for the aggregated dataset.
+ data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
+ video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
+ chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
+ """
+ logging.info("Start aggregate_datasets")
+
+ if data_files_size_in_mb is None:
+ data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
+ if video_files_size_in_mb is None:
+ video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
+ if chunk_size is None:
+ chunk_size = DEFAULT_CHUNK_SIZE
+
+ all_metadata = (
+ [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
+ if roots is None
+ else [
+ LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
+ ]
+ )
+ fps, robot_type, features = validate_all_metadata(all_metadata)
+ video_keys = [key for key in features if features[key]["dtype"] == "video"]
+
+ dst_meta = LeRobotDatasetMetadata.create(
+ repo_id=aggr_repo_id,
+ fps=fps,
+ robot_type=robot_type,
+ features=features,
+ root=aggr_root,
+ )
+
+ logging.info("Find all tasks")
+ unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
+ dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
+
+ meta_idx = {"chunk": 0, "file": 0}
+ data_idx = {"chunk": 0, "file": 0}
+ videos_idx = {
+ key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
+ }
+
+ dst_meta.episodes = {}
+
+ for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
+ videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
+ data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
+
+ meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
+
+ dst_meta.info["total_episodes"] += src_meta.total_episodes
+ dst_meta.info["total_frames"] += src_meta.total_frames
+
+ finalize_aggregation(dst_meta, all_metadata)
+ logging.info("Aggregation complete.")
+
+
+def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
+ """Aggregates video chunks from a source dataset into the destination dataset.
+
+ Handles video file concatenation and rotation based on file size limits.
+ Creates new video files when size limits are exceeded.
+
+ Args:
+ src_meta: Source dataset metadata.
+ dst_meta: Destination dataset metadata.
+ videos_idx: Dictionary tracking video chunk and file indices.
+ video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
+ chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
+
+ Returns:
+ dict: Updated videos_idx with current chunk and file indices.
+ """
+ for key, video_idx in videos_idx.items():
+ unique_chunk_file_pairs = {
+ (chunk, file)
+ for chunk, file in zip(
+ src_meta.episodes[f"videos/{key}/chunk_index"],
+ src_meta.episodes[f"videos/{key}/file_index"],
+ strict=False,
+ )
+ }
+ unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)
+
+ chunk_idx = video_idx["chunk"]
+ file_idx = video_idx["file"]
+
+ for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
+ src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
+ video_key=key,
+ chunk_index=src_chunk_idx,
+ file_index=src_file_idx,
+ )
+
+ dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
+ video_key=key,
+ chunk_index=chunk_idx,
+ file_index=file_idx,
+ )
+
+ # If a new file is created, we don't want to increment the latest_duration
+ update_latest_duration = False
+
+ if not dst_path.exists():
+ # First write to this destination file
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(str(src_path), str(dst_path))
+ continue # not accumulating further, already copied the file in place
+
+ # Check file sizes before appending
+ src_size = get_video_size_in_mb(src_path)
+ dst_size = get_video_size_in_mb(dst_path)
+
+ if dst_size + src_size >= video_files_size_in_mb:
+ # Rotate to a new chunk/file
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
+ dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
+ video_key=key,
+ chunk_index=chunk_idx,
+ file_index=file_idx,
+ )
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(str(src_path), str(dst_path))
+ else:
+ # Get the timestamps shift for this video
+ timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
+
+ # Append to existing video file
+ concatenate_video_files(
+ [dst_path, src_path],
+ dst_path,
+ )
+ # Update the latest_duration when appending (shifts timestamps!)
+ update_latest_duration = not update_latest_duration
+
+ # Update the videos_idx with the final chunk and file indices for this key
+ videos_idx[key]["chunk"] = chunk_idx
+ videos_idx[key]["file"] = file_idx
+
+ if update_latest_duration:
+ videos_idx[key]["latest_duration"] += timestamps_shift_s
+
+ return videos_idx
+
+
+def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
+ """Aggregates data chunks from a source dataset into the destination dataset.
+
+ Reads source data files, updates indices to match the aggregated dataset,
+ and writes them to the destination with proper file rotation.
+
+ Args:
+ src_meta: Source dataset metadata.
+ dst_meta: Destination dataset metadata.
+ data_idx: Dictionary tracking data chunk and file indices.
+
+ Returns:
+ dict: Updated data_idx with current chunk and file indices.
+ """
+ unique_chunk_file_ids = {
+ (c, f)
+ for c, f in zip(
+ src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
+ )
+ }
+
+ unique_chunk_file_ids = sorted(unique_chunk_file_ids)
+
+ for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
+ src_path = src_meta.root / DEFAULT_DATA_PATH.format(
+ chunk_index=src_chunk_idx, file_index=src_file_idx
+ )
+ df = pd.read_parquet(src_path)
+ df = update_data_df(df, src_meta, dst_meta)
+
+ data_idx = append_or_create_parquet_file(
+ df,
+ src_path,
+ data_idx,
+ data_files_size_in_mb,
+ chunk_size,
+ DEFAULT_DATA_PATH,
+ contains_images=len(dst_meta.image_keys) > 0,
+ aggr_root=dst_meta.root,
+ )
+
+ return data_idx
+
+
+def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
+ """Aggregates metadata from a source dataset into the destination dataset.
+
+ Reads source metadata files, updates all indices and timestamps,
+ and writes them to the destination with proper file rotation.
+
+ Args:
+ src_meta: Source dataset metadata.
+ dst_meta: Destination dataset metadata.
+ meta_idx: Dictionary tracking metadata chunk and file indices.
+ data_idx: Dictionary tracking data chunk and file indices.
+ videos_idx: Dictionary tracking video indices and timestamps.
+
+ Returns:
+ dict: Updated meta_idx with current chunk and file indices.
+ """
+ chunk_file_ids = {
+ (c, f)
+ for c, f in zip(
+ src_meta.episodes["meta/episodes/chunk_index"],
+ src_meta.episodes["meta/episodes/file_index"],
+ strict=False,
+ )
+ }
+
+ chunk_file_ids = sorted(chunk_file_ids)
+ for chunk_idx, file_idx in chunk_file_ids:
+ src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ df = pd.read_parquet(src_path)
+ df = update_meta_data(
+ df,
+ dst_meta,
+ meta_idx,
+ data_idx,
+ videos_idx,
+ )
+
+ for k in videos_idx:
+ videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
+
+ meta_idx = append_or_create_parquet_file(
+ df,
+ src_path,
+ meta_idx,
+ DEFAULT_DATA_FILE_SIZE_IN_MB,
+ DEFAULT_CHUNK_SIZE,
+ DEFAULT_EPISODES_PATH,
+ contains_images=False,
+ aggr_root=dst_meta.root,
+ )
+
+ return meta_idx
+
+
+def append_or_create_parquet_file(
+ df: pd.DataFrame,
+ src_path: Path,
+ idx: dict[str, int],
+ max_mb: float,
+ chunk_size: int,
+ default_path: str,
+ contains_images: bool = False,
+ aggr_root: Path = None,
+):
+ """Appends data to an existing parquet file or creates a new one based on size constraints.
+
+ Manages file rotation when size limits are exceeded to prevent individual files
+ from becoming too large. Handles both regular parquet files and those containing images.
+
+ Args:
+ df: DataFrame to write to the parquet file.
+ src_path: Path to the source file (used for size estimation).
+ idx: Dictionary containing current 'chunk' and 'file' indices.
+ max_mb: Maximum allowed file size in MB before rotation.
+ chunk_size: Maximum number of files per chunk before incrementing chunk index.
+ default_path: Format string for generating file paths.
+ contains_images: Whether the data contains images requiring special handling.
+ aggr_root: Root path for the aggregated dataset.
+
+ Returns:
+ dict: Updated index dictionary with current chunk and file indices.
+ """
+ dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
+
+ if not dst_path.exists():
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+ if contains_images:
+ to_parquet_with_hf_images(df, dst_path)
+ else:
+ df.to_parquet(dst_path)
+ return idx
+
+ src_size = get_parquet_file_size_in_mb(src_path)
+ dst_size = get_parquet_file_size_in_mb(dst_path)
+
+ if dst_size + src_size >= max_mb:
+ idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
+ new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
+ new_path.parent.mkdir(parents=True, exist_ok=True)
+ final_df = df
+ target_path = new_path
+ else:
+ existing_df = pd.read_parquet(dst_path)
+ final_df = pd.concat([existing_df, df], ignore_index=True)
+ target_path = dst_path
+
+ if contains_images:
+ to_parquet_with_hf_images(final_df, target_path)
+ else:
+ final_df.to_parquet(target_path)
+
+ return idx
+
+
+def finalize_aggregation(aggr_meta, all_metadata):
+ """Finalizes the dataset aggregation by writing summary files and statistics.
+
+ Writes the tasks file, info file with total counts and splits, and
+ aggregated statistics from all source datasets.
+
+ Args:
+ aggr_meta: Aggregated dataset metadata.
+ all_metadata: List of all source dataset metadata objects.
+ """
+ logging.info("write tasks")
+ write_tasks(aggr_meta.tasks, aggr_meta.root)
+
+ logging.info("write info")
+ aggr_meta.info.update(
+ {
+ "total_tasks": len(aggr_meta.tasks),
+ "total_episodes": sum(m.total_episodes for m in all_metadata),
+ "total_frames": sum(m.total_frames for m in all_metadata),
+ "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
+ }
+ )
+ write_info(aggr_meta.info, aggr_meta.root)
+
+ logging.info("write stats")
+ aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
+ write_stats(aggr_meta.stats, aggr_meta.root)
diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py
index fae485058..ae95c5f7b 100644
--- a/src/lerobot/datasets/backward_compatibility.py
+++ b/src/lerobot/datasets/backward_compatibility.py
@@ -14,34 +14,17 @@
import packaging.version
-V2_MESSAGE = """
+V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
-We introduced a new format since v2.0 which is not backward compatible with v1.x.
-Please, use our conversion script. Modify the following command with your own task description:
+We introduced a new format since v3.0 which is not backward compatible with v2.1.
+Please, update your dataset to the new format using this command:
```
-python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \\
- --repo-id {repo_id} \\
- --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
+python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```
-A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
-peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
-cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
-target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
-sweatshirt.", ...
-
-If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
-or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
-"""
-
-V21_MESSAGE = """
-The dataset you requested ({repo_id}) is in {version} format.
-While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
-stats instead of per-episode stats. Update your dataset stats to the new format using this command:
-```
-python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id={repo_id}
-```
+If you already have a converted version uploaded to the hub, then this error might be because of
+an older version in your local cache. Consider deleting the cached version and retrying.
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
@@ -58,7 +41,12 @@ class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
- message = V2_MESSAGE.format(repo_id=repo_id, version=version)
+ if version.major == 2 and version.minor == 1:
+ message = V30_MESSAGE.format(repo_id=repo_id, version=version)
+ else:
+ raise NotImplementedError(
+ "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
+ )
super().__init__(message)
diff --git a/src/lerobot/datasets/card_template.md b/src/lerobot/datasets/card_template.md
index 7ee27df95..ee26a78f5 100644
--- a/src/lerobot/datasets/card_template.md
+++ b/src/lerobot/datasets/card_template.md
@@ -1,7 +1,8 @@
---
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
-{{ card_data }}
+# prettier-ignore
+{{card_data}}
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py
index bfe7b18b4..61e174d5c 100644
--- a/src/lerobot/datasets/compute_stats.py
+++ b/src/lerobot/datasets/compute_stats.py
@@ -17,6 +17,179 @@ import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
+DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
+
+
+class RunningQuantileStats:
+ """
+ Maintains running statistics for batches of vectors, including mean,
+ standard deviation, min, max, and approximate quantiles.
+
+ Statistics are computed per feature dimension and updated incrementally
+ as new batches are observed. Quantiles are estimated using histograms,
+ which adapt dynamically if the observed data range expands.
+ """
+
+ def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
+ self._count = 0
+ self._mean = None
+ self._mean_of_squares = None
+ self._min = None
+ self._max = None
+ self._histograms = None
+ self._bin_edges = None
+ self._num_quantile_bins = num_quantile_bins
+
+ self._quantile_list = quantile_list
+ if self._quantile_list is None:
+ self._quantile_list = DEFAULT_QUANTILES
+ self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
+
+ def update(self, batch: np.ndarray) -> None:
+ """Update the running statistics with a batch of vectors.
+
+ Args:
+ batch: An array where all dimensions except the last are batch dimensions.
+ """
+ batch = batch.reshape(-1, batch.shape[-1])
+ num_elements, vector_length = batch.shape
+
+ if self._count == 0:
+ self._mean = np.mean(batch, axis=0)
+ self._mean_of_squares = np.mean(batch**2, axis=0)
+ self._min = np.min(batch, axis=0)
+ self._max = np.max(batch, axis=0)
+ self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
+ self._bin_edges = [
+ np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
+ for i in range(vector_length)
+ ]
+ else:
+ if vector_length != self._mean.size:
+ raise ValueError("The length of new vectors does not match the initialized vector length.")
+
+ new_max = np.max(batch, axis=0)
+ new_min = np.min(batch, axis=0)
+ max_changed = np.any(new_max > self._max)
+ min_changed = np.any(new_min < self._min)
+ self._max = np.maximum(self._max, new_max)
+ self._min = np.minimum(self._min, new_min)
+
+ if max_changed or min_changed:
+ self._adjust_histograms()
+
+ self._count += num_elements
+
+ batch_mean = np.mean(batch, axis=0)
+ batch_mean_of_squares = np.mean(batch**2, axis=0)
+
+ # Update running mean and mean of squares
+ self._mean += (batch_mean - self._mean) * (num_elements / self._count)
+ self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
+ num_elements / self._count
+ )
+
+ self._update_histograms(batch)
+
+ def get_statistics(self) -> dict[str, np.ndarray]:
+ """Compute and return the statistics of the vectors processed so far.
+
+ Args:
+ quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
+
+ Returns:
+ Dictionary containing the computed statistics.
+ """
+ if self._count < 2:
+ raise ValueError("Cannot compute statistics for less than 2 vectors.")
+
+ variance = self._mean_of_squares - self._mean**2
+
+ stddev = np.sqrt(np.maximum(0, variance))
+
+ stats = {
+ "min": self._min.copy(),
+ "max": self._max.copy(),
+ "mean": self._mean.copy(),
+ "std": stddev,
+ "count": np.array([self._count]),
+ }
+
+ quantile_results = self._compute_quantiles()
+ for i, q in enumerate(self._quantile_keys):
+ stats[q] = quantile_results[i]
+
+ return stats
+
+ def _adjust_histograms(self):
+ """Adjust histograms when min or max changes."""
+ for i in range(len(self._histograms)):
+ old_edges = self._bin_edges[i]
+ old_hist = self._histograms[i]
+
+ # Create new edges with small padding to ensure range coverage
+ padding = (self._max[i] - self._min[i]) * 1e-10
+ new_edges = np.linspace(
+ self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
+ )
+
+ # Redistribute existing histogram counts to new bins
+ # We need to map each old bin center to the new bins
+ old_centers = (old_edges[:-1] + old_edges[1:]) / 2
+ new_hist = np.zeros(self._num_quantile_bins)
+
+ for old_center, count in zip(old_centers, old_hist, strict=False):
+ if count > 0:
+ # Find which new bin this old center belongs to
+ bin_idx = np.searchsorted(new_edges, old_center) - 1
+ bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
+ new_hist[bin_idx] += count
+
+ self._histograms[i] = new_hist
+ self._bin_edges[i] = new_edges
+
+ def _update_histograms(self, batch: np.ndarray) -> None:
+ """Update histograms with new vectors."""
+ for i in range(batch.shape[1]):
+ hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
+ self._histograms[i] += hist
+
+ def _compute_quantiles(self) -> list[np.ndarray]:
+ """Compute quantiles based on histograms."""
+ results = []
+ for q in self._quantile_list:
+ target_count = q * self._count
+ q_values = []
+
+ for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
+ q_value = self._compute_single_quantile(hist, edges, target_count)
+ q_values.append(q_value)
+
+ results.append(np.array(q_values))
+ return results
+
+ def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
+ """Compute a single quantile value from histogram and bin edges."""
+ cumsum = np.cumsum(hist)
+ idx = np.searchsorted(cumsum, target_count)
+
+ if idx == 0:
+ return edges[0]
+ if idx >= len(cumsum):
+ return edges[-1]
+
+ # If not edge case, interpolate within the bin
+ count_before = cumsum[idx - 1]
+ count_in_bin = cumsum[idx] - count_before
+
+ # If no samples in this bin, use the bin edge
+ if count_in_bin == 0:
+ return edges[idx]
+
+ # Linear interpolation within the bin
+ fraction = (target_count - count_before) / count_in_bin
+ return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
+
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
-def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
- return {
- "min": np.min(array, axis=axis, keepdims=keepdims),
- "max": np.max(array, axis=axis, keepdims=keepdims),
- "mean": np.mean(array, axis=axis, keepdims=keepdims),
- "std": np.std(array, axis=axis, keepdims=keepdims),
- "count": np.array([len(array)]),
+def _reshape_stats_by_axis(
+ stats: dict[str, np.ndarray],
+ axis: int | tuple[int, ...] | None,
+ keepdims: bool,
+ original_shape: tuple[int, ...],
+) -> dict[str, np.ndarray]:
+ """Reshape all statistics to match NumPy's output conventions.
+
+ Applies consistent reshaping to all statistics (except 'count') based on the
+ axis and keepdims parameters. This ensures statistics have the correct shape
+ for broadcasting with the original data.
+
+ Args:
+ stats: Dictionary of computed statistics
+ axis: Axis or axes along which statistics were computed
+ keepdims: Whether to keep reduced dimensions as size-1 dimensions
+ original_shape: Shape of the original array
+
+ Returns:
+ Dictionary with reshaped statistics
+
+ Note:
+ The 'count' statistic is never reshaped as it represents metadata
+ rather than per-feature statistics.
+ """
+ if axis == (1,) and not keepdims:
+ return stats
+
+ result = {}
+ for key, value in stats.items():
+ if key == "count":
+ result[key] = value
+ else:
+ result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
+
+ return result
+
+
+def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
+ """Reshape statistics for image data (axis=(0,2,3))."""
+ if keepdims and value.ndim == 1:
+ return value.reshape(1, -1, 1, 1)
+ return value
+
+
+def _reshape_for_vector_stats(
+ value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
+) -> np.ndarray:
+ """Reshape statistics for vector data (axis=0 or axis=(0,))."""
+ if not keepdims:
+ return value
+
+ if len(original_shape) == 1 and value.ndim > 0:
+ return value.reshape(1)
+ elif len(original_shape) >= 2 and value.ndim == 1:
+ return value.reshape(1, -1)
+ return value
+
+
+def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
+ """Reshape statistics for feature-wise computation (axis=(1,))."""
+ if not keepdims:
+ return value
+
+ if value.ndim == 0:
+ return value.reshape(1, 1)
+ elif value.ndim == 1:
+ return value.reshape(-1, 1)
+ return value
+
+
+def _reshape_for_global_stats(
+ value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
+) -> np.ndarray | float:
+ """Reshape statistics for global reduction (axis=None)."""
+ if keepdims:
+ target_shape = tuple(1 for _ in original_shape)
+ return value.reshape(target_shape)
+ # Keep at least 1-D arrays to satisfy validator
+ return np.atleast_1d(value)
+
+
+def _reshape_single_stat(
+ value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
+) -> np.ndarray | float:
+ """Apply appropriate reshaping to a single statistic array.
+
+ This function transforms statistic arrays to match expected output shapes
+ based on the axis configuration and keepdims parameter.
+
+ Args:
+ value: The statistic array to reshape
+ axis: Axis or axes that were reduced during computation
+ keepdims: Whether to maintain reduced dimensions as size-1 dimensions
+ original_shape: Shape of the original data before reduction
+
+ Returns:
+ Reshaped array following NumPy broadcasting conventions
+
+ """
+ if axis == (0, 2, 3):
+ return _reshape_for_image_stats(value, keepdims)
+
+ if axis in [0, (0,)]:
+ return _reshape_for_vector_stats(value, keepdims, original_shape)
+
+ if axis == (1,):
+ return _reshape_for_feature_stats(value, keepdims)
+
+ if axis is None:
+ return _reshape_for_global_stats(value, keepdims, original_shape)
+
+ return value
+
+
+def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
+ """Prepare array for statistics computation by reshaping according to axis.
+
+ Args:
+ array: Input data array
+ axis: Axis or axes along which to compute statistics
+
+ Returns:
+ Tuple of (reshaped_array, sample_count)
+ """
+ if axis == (0, 2, 3): # Image data
+ batch_size, channels, height, width = array.shape
+ reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
+ return reshaped, batch_size
+
+ if axis == 0 or axis == (0,): # Vector data
+ reshaped = array
+ if array.ndim == 1:
+ reshaped = array.reshape(-1, 1)
+ return reshaped, array.shape[0]
+
+ if axis == (1,): # Feature-wise statistics
+ return array.T, array.shape[1]
+
+ if axis is None: # Global statistics
+ reshaped = array.reshape(-1, 1)
+ # For backward compatibility, count represents the first dimension size
+ return reshaped, array.shape[0] if array.ndim > 0 else 1
+
+ raise ValueError(f"Unsupported axis configuration: {axis}")
+
+
+def _compute_basic_stats(
+ array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
+) -> dict[str, np.ndarray]:
+ """Compute basic statistics for arrays with insufficient samples for quantiles.
+
+ Args:
+ array: Reshaped array ready for statistics computation
+ sample_count: Number of samples represented in the data
+
+ Returns:
+ Dictionary with basic statistics and quantiles set to mean values
+ """
+ if quantile_list is None:
+ quantile_list = DEFAULT_QUANTILES
+ quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
+
+ stats = {
+ "min": np.min(array, axis=0),
+ "max": np.max(array, axis=0),
+ "mean": np.mean(array, axis=0),
+ "std": np.std(array, axis=0),
+ "count": np.array([sample_count]),
}
+ for q in quantile_list_keys:
+ stats[q] = stats["mean"].copy()
+
+ return stats
+
+
+def get_feature_stats(
+ array: np.ndarray,
+ axis: int | tuple[int, ...] | None,
+ keepdims: bool,
+ quantile_list: list[float] | None = None,
+) -> dict[str, np.ndarray]:
+ """Compute comprehensive statistics for array features along specified axes.
+
+ This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
+ for the input array along the specified axes. It handles different data layouts:
+ - Image data: axis=(0,2,3) computes per-channel statistics
+ - Vector data: axis=0 computes per-feature statistics
+ - Feature-wise: axis=1 computes statistics across features
+ - Global: axis=None computes statistics over entire array
+
+ Args:
+ array: Input data array with shape appropriate for the specified axis
+ axis: Axis or axes along which to compute statistics
+ - (0, 2, 3): For image data (batch, channels, height, width)
+ - 0 or (0,): For vector/tabular data (samples, features)
+ - (1,): For computing across features
+ - None: For global statistics over entire array
+ keepdims: If True, reduced axes are kept as dimensions with size 1
+
+ Returns:
+ Dictionary containing:
+ - 'min': Minimum values
+ - 'max': Maximum values
+ - 'mean': Mean values
+ - 'std': Standard deviation
+ - 'count': Number of samples (always shape (1,))
+ - 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
+
+ """
+ if quantile_list is None:
+ quantile_list = DEFAULT_QUANTILES
+
+ original_shape = array.shape
+ reshaped, sample_count = _prepare_array_for_stats(array, axis)
+
+ if reshaped.shape[0] < 2:
+ stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
+ else:
+ running_stats = RunningQuantileStats()
+ running_stats.update(reshaped)
+ stats = running_stats.get_statistics()
+ stats["count"] = np.array([sample_count])
+
+ stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
+ return stats
+
+
+def compute_episode_stats(
+ episode_data: dict[str, list[str] | np.ndarray],
+ features: dict,
+ quantile_list: list[float] | None = None,
+) -> dict:
+ """Compute comprehensive statistics for all features in an episode.
+
+ Processes different data types appropriately:
+ - Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
+ - Numerical arrays: Computes per-feature statistics
+ - Strings: Skipped (no statistics computed)
+
+ Args:
+ episode_data: Dictionary mapping feature names to data
+ - For images/videos: list of file paths
+ - For numerical data: numpy arrays
+ features: Dictionary describing each feature's dtype and shape
+
+ Returns:
+ Dictionary mapping feature names to their statistics dictionaries.
+ Each statistics dictionary contains min, max, mean, std, count, and quantiles.
+
+ Note:
+ Image statistics are normalized to [0,1] range and have shape (3,1,1) for
+ per-channel values when dtype is 'image' or 'video'.
+ """
+ if quantile_list is None:
+ quantile_list = DEFAULT_QUANTILES
-def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
- continue # HACK: we should receive np.arrays of strings
- elif features[key]["dtype"] in ["image", "video"]:
- ep_ft_array = sample_images(data) # data is a list of image paths
- axes_to_reduce = (0, 2, 3) # keep channel dim
+ continue
+
+ if features[key]["dtype"] in ["image", "video"]:
+ ep_ft_array = sample_images(data)
+ axes_to_reduce = (0, 2, 3)
keepdims = True
else:
- ep_ft_array = data # data is already a np.ndarray
- axes_to_reduce = 0 # compute stats over the first axis
- keepdims = data.ndim == 1 # keep as np.array
+ ep_ft_array = data
+ axes_to_reduce = 0
+ keepdims = data.ndim == 1
- ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
+ ep_stats[key] = get_feature_stats(
+ ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
+ )
- # finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
return ep_stats
+def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
+ """Validate a single statistic value."""
+ if not isinstance(value, np.ndarray):
+ raise ValueError(
+ f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
+ f"is of type '{type(value)}' instead."
+ )
+
+ if value.ndim == 0:
+ raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
+
+ if key == "count" and value.shape != (1,):
+ raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
+
+ if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
+ raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
+
+
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
- for i in range(len(stats_list)):
- for fkey in stats_list[i]:
- for k, v in stats_list[i][fkey].items():
- if not isinstance(v, np.ndarray):
- raise ValueError(
- f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
- )
- if v.ndim == 0:
- raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
- if k == "count" and v.shape != (1,):
- raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
- if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
- raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
+ """Validate that all statistics have correct types and shapes.
+
+ Args:
+ stats_list: List of statistics dictionaries to validate
+
+ Raises:
+ ValueError: If any statistic has incorrect type or shape
+ """
+ for stats in stats_list:
+ for feature_key, feature_stats in stats.items():
+ for stat_key, stat_value in feature_stats.items():
+ _validate_stat_value(stat_value, stat_key, feature_key)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
- return {
+ aggregated = {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
"count": total_count,
}
+ if stats_ft_list:
+ quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
+
+ for q_key in quantile_keys:
+ if all(q_key in s for s in stats_ft_list):
+ quantile_values = np.stack([s[q_key] for s in stats_ft_list])
+ weighted_quantiles = quantile_values * counts
+ aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
+
+ return aggregated
+
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py
index e06650bc9..f3ceb2b0c 100644
--- a/src/lerobot/datasets/factory.py
+++ b/src/lerobot/datasets/factory.py
@@ -25,7 +25,9 @@ from lerobot.datasets.lerobot_dataset import (
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
+from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -53,11 +55,11 @@ def resolve_delta_timestamps(
"""
delta_timestamps = {}
for key in ds_meta.features:
- if key == "next.reward" and cfg.reward_delta_indices is not None:
+ if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
- if key == "action" and cfg.action_delta_indices is not None:
+ if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
- if key.startswith("observation.") and cfg.observation_delta_indices is not None:
+ if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
@@ -87,15 +89,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
- dataset = LeRobotDataset(
- cfg.dataset.repo_id,
- root=cfg.dataset.root,
- episodes=cfg.dataset.episodes,
- delta_timestamps=delta_timestamps,
- image_transforms=image_transforms,
- revision=cfg.dataset.revision,
- video_backend=cfg.dataset.video_backend,
- )
+ if not cfg.dataset.streaming:
+ dataset = LeRobotDataset(
+ cfg.dataset.repo_id,
+ root=cfg.dataset.root,
+ episodes=cfg.dataset.episodes,
+ delta_timestamps=delta_timestamps,
+ image_transforms=image_transforms,
+ revision=cfg.dataset.revision,
+ video_backend=cfg.dataset.video_backend,
+ )
+ else:
+ dataset = StreamingLeRobotDataset(
+ cfg.dataset.repo_id,
+ root=cfg.dataset.root,
+ episodes=cfg.dataset.episodes,
+ delta_timestamps=delta_timestamps,
+ image_transforms=image_transforms,
+ revision=cfg.dataset.revision,
+ max_num_shards=cfg.num_workers,
+ )
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 1a3dd1e1b..b661b21b0 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -14,66 +14,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
+import gc
import logging
import shutil
+import tempfile
+from collections.abc import Callable
from pathlib import Path
-from typing import Callable
import datasets
import numpy as np
import packaging.version
+import pandas as pd
import PIL.Image
import torch
import torch.utils
-from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
-from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
-from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
+ DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
INFO_PATH,
- TASKS_PATH,
_validate_feature_names,
- append_jsonlines,
- backward_compatible_episodes_stats,
check_delta_timestamps,
- check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
+ flatten_dict,
get_delta_indices,
- get_episode_data_index,
+ get_hf_dataset_cache_dir,
+ get_hf_dataset_size_in_mb,
get_hf_features_from_features,
+ get_parquet_file_size_in_mb,
+ get_parquet_num_frames,
get_safe_version,
+ get_video_size_in_mb,
hf_transform_to_torch,
is_valid_version,
load_episodes,
- load_episodes_stats,
load_info,
+ load_nested_dataset,
load_stats,
load_tasks,
+ to_parquet_with_hf_images,
+ update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
- write_episode,
- write_episode_stats,
write_info,
write_json,
+ write_stats,
+ write_tasks,
)
from lerobot.datasets.video_utils import (
VideoFrame,
+ concatenate_video_files,
decode_video_frames,
encode_video_frames,
get_safe_default_codec,
+ get_video_duration_in_s,
get_video_info,
)
+from lerobot.utils.constants import HF_LEROBOT_HOME
-CODEBASE_VERSION = "v2.1"
+CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
@@ -103,14 +110,9 @@ class LeRobotDatasetMetadata:
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
- self.tasks, self.task_to_task_index = load_tasks(self.root)
+ self.tasks = load_tasks(self.root)
self.episodes = load_episodes(self.root)
- if self._version < packaging.version.parse("v2.1"):
- self.stats = load_stats(self.root)
- self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
- else:
- self.episodes_stats = load_episodes_stats(self.root)
- self.stats = aggregate_stats(list(self.episodes_stats.values()))
+ self.stats = load_stats(self.root)
def pull_from_repo(
self,
@@ -126,24 +128,29 @@ class LeRobotDatasetMetadata:
ignore_patterns=ignore_patterns,
)
+ @property
+ def url_root(self) -> str:
+ return f"hf://datasets/{self.repo_id}"
+
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
- ep_chunk = self.get_episode_chunk(ep_index)
- fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
+ ep = self.episodes[ep_index]
+ chunk_idx = ep["data/chunk_index"]
+ file_idx = ep["data/file_index"]
+ fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
- ep_chunk = self.get_episode_chunk(ep_index)
- fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+ ep = self.episodes[ep_index]
+ chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
+ file_idx = ep[f"videos/{vid_key}/file_index"]
+ fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
- def get_episode_chunk(self, ep_index: int) -> int:
- return ep_index // self.chunks_size
-
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
@@ -210,39 +217,115 @@ class LeRobotDatasetMetadata:
return self.info["total_tasks"]
@property
- def total_chunks(self) -> int:
- """Total number of chunks (groups of episodes)."""
- return self.info["total_chunks"]
+ def chunks_size(self) -> int:
+ """Max number of files per chunk."""
+ return self.info["chunks_size"]
@property
- def chunks_size(self) -> int:
- """Max number of episodes per chunk."""
- return self.info["chunks_size"]
+ def data_files_size_in_mb(self) -> int:
+ """Max size of data file in mega bytes."""
+ return self.info["data_files_size_in_mb"]
+
+ @property
+ def video_files_size_in_mb(self) -> int:
+ """Max size of video file in mega bytes."""
+ return self.info["video_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise return None.
"""
- return self.task_to_task_index.get(task, None)
+ if task in self.tasks.index:
+ return int(self.tasks.loc[task].task_index)
+ else:
+ return None
- def add_task(self, task: str):
+ def save_episode_tasks(self, tasks: list[str]):
+ if len(set(tasks)) != len(tasks):
+ raise ValueError(f"Tasks are not unique: {tasks}")
+
+ if self.tasks is None:
+ new_tasks = tasks
+ task_indices = range(len(tasks))
+ self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
+ else:
+ new_tasks = [task for task in tasks if task not in self.tasks.index]
+ new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
+ for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
+ self.tasks.loc[task] = task_idx
+
+ if len(new_tasks) > 0:
+ # Update on disk
+ write_tasks(self.tasks, self.root)
+
+ def _save_episode_metadata(self, episode_dict: dict) -> None:
+ """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
+
+ This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
+ and saves it as a parquet file. It handles both the creation of new parquet files and the
+ updating of existing ones based on size constraints. After saving the metadata, it reloads
+ the Hugging Face dataset to ensure it is up-to-date.
+
+ Notes: We both need to update parquet files and HF dataset:
+ - `pandas` loads parquet file in RAM
+ - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
+ or loads directly from pyarrow cache.
"""
- Given a task in natural language, add it to the dictionary of tasks.
- """
- if task in self.task_to_task_index:
- raise ValueError(f"The task '{task}' already exists and can't be added twice.")
+ # Convert buffer into HF Dataset
+ episode_dict = {key: [value] for key, value in episode_dict.items()}
+ ep_dataset = datasets.Dataset.from_dict(episode_dict)
+ ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
+ df = pd.DataFrame(ep_dataset)
+ num_frames = episode_dict["length"][0]
- task_index = self.info["total_tasks"]
- self.task_to_task_index[task] = task_index
- self.tasks[task_index] = task
- self.info["total_tasks"] += 1
+ if self.episodes is None:
+ # Initialize indices and frame count for a new dataset made of the first episode data
+ chunk_idx, file_idx = 0, 0
+ df["meta/episodes/chunk_index"] = [chunk_idx]
+ df["meta/episodes/file_index"] = [file_idx]
+ df["dataset_from_index"] = [0]
+ df["dataset_to_index"] = [num_frames]
+ else:
+ # Retrieve information from the latest parquet file
+ latest_ep = self.episodes[-1]
+ chunk_idx = latest_ep["meta/episodes/chunk_index"]
+ file_idx = latest_ep["meta/episodes/file_index"]
- task_dict = {
- "task_index": task_index,
- "task": task,
- }
- append_jsonlines(task_dict, self.root / TASKS_PATH)
+ latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+
+ if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
+ # Size limit is reached, prepare new parquet file
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
+
+ # Update the existing pandas dataframe with new row
+ df["meta/episodes/chunk_index"] = [chunk_idx]
+ df["meta/episodes/file_index"] = [file_idx]
+ df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
+ df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
+
+ if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
+ # Size limit wasnt reached, concatenate latest dataframe with new one
+ latest_df = pd.read_parquet(latest_path)
+ df = pd.concat([latest_df, df], ignore_index=True)
+
+ # Memort optimization
+ del latest_df
+ gc.collect()
+
+ # Write the resulting dataframe from RAM to disk
+ path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ df.to_parquet(path, index=False)
+
+ if self.episodes is not None:
+ # Remove the episodes cache directory, necessary to avoid cache bloat
+ cached_dir = get_hf_dataset_cache_dir(self.episodes)
+ if cached_dir is not None:
+ shutil.rmtree(cached_dir)
+
+ self.episodes = load_episodes(self.root)
def save_episode(
self,
@@ -250,43 +333,91 @@ class LeRobotDatasetMetadata:
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
+ episode_metadata: dict,
) -> None:
- self.info["total_episodes"] += 1
- self.info["total_frames"] += episode_length
-
- chunk = self.get_episode_chunk(episode_index)
- if chunk >= self.total_chunks:
- self.info["total_chunks"] += 1
-
- self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
- self.info["total_videos"] += len(self.video_keys)
- if len(self.video_keys) > 0:
- self.update_video_info()
-
- write_info(self.info, self.root)
-
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
"length": episode_length,
}
- self.episodes[episode_index] = episode_dict
- write_episode(episode_dict, self.root)
+ episode_dict.update(episode_metadata)
+ episode_dict.update(flatten_dict({"stats": episode_stats}))
+ self._save_episode_metadata(episode_dict)
- self.episodes_stats[episode_index] = episode_stats
- self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
- write_episode_stats(episode_index, episode_stats, self.root)
+ # Update info
+ self.info["total_episodes"] += 1
+ self.info["total_frames"] += episode_length
+ self.info["total_tasks"] = len(self.tasks)
+ self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
- def update_video_info(self) -> None:
+ write_info(self.info, self.root)
+
+ self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
+ write_stats(self.stats, self.root)
+
+ def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
- for key in self.video_keys:
+ if video_key is not None and video_key not in self.video_keys:
+ raise ValueError(f"Video key {video_key} not found in dataset")
+
+ video_keys = [video_key] if video_key is not None else self.video_keys
+ for key in video_keys:
if not self.features[key].get("info", None):
- video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
+ video_path = self.root / self.video_path.format(
+ video_key=video_key, chunk_index=0, file_index=0
+ )
self.info["features"][key]["info"] = get_video_info(video_path)
+ def update_chunk_settings(
+ self,
+ chunks_size: int | None = None,
+ data_files_size_in_mb: int | None = None,
+ video_files_size_in_mb: int | None = None,
+ ) -> None:
+ """Update chunk and file size settings after dataset creation.
+
+ This allows users to customize storage organization without modifying the constructor.
+ These settings control how episodes are chunked and how large files can grow before
+ creating new ones.
+
+ Args:
+ chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
+ data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
+ video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
+ """
+ if chunks_size is not None:
+ if chunks_size <= 0:
+ raise ValueError(f"chunks_size must be positive, got {chunks_size}")
+ self.info["chunks_size"] = chunks_size
+
+ if data_files_size_in_mb is not None:
+ if data_files_size_in_mb <= 0:
+ raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
+ self.info["data_files_size_in_mb"] = data_files_size_in_mb
+
+ if video_files_size_in_mb is not None:
+ if video_files_size_in_mb <= 0:
+ raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
+ self.info["video_files_size_in_mb"] = video_files_size_in_mb
+
+ # Update the info file on disk
+ write_info(self.info, self.root)
+
+ def get_chunk_settings(self) -> dict[str, int]:
+ """Get current chunk and file size settings.
+
+ Returns:
+ Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
+ """
+ return {
+ "chunks_size": self.chunks_size,
+ "data_files_size_in_mb": self.data_files_size_in_mb,
+ "video_files_size_in_mb": self.video_files_size_in_mb,
+ }
+
def __repr__(self):
feature_keys = list(self.features)
return (
@@ -315,12 +446,12 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
- # TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
- obj.tasks, obj.task_to_task_index = {}, {}
- obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
+ obj.tasks = None
+ obj.episodes = None
+ obj.stats = None
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
@@ -336,12 +467,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
- delta_timestamps: dict[list[float]] | None = None,
+ delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
+ batch_encoding_size: int = 1,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -355,9 +487,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
the dataset from that address and load it, pending your dataset is compliant with
- codebase_version v2.0. If your dataset has been created before this new format, you will be
- prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at
- lerobot/datasets/v2/convert_dataset_v1_to_v2.py.
+ codebase_version v3.0. If your dataset has been created before this new format, you will be
+ prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at
+ lerobot/datasets/v30/convert_dataset_v21_to_v30.py.
2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
@@ -378,38 +510,47 @@ class LeRobotDataset(torch.utils.data.Dataset):
.
├── data
│ ├── chunk-000
- │ │ ├── episode_000000.parquet
- │ │ ├── episode_000001.parquet
- │ │ ├── episode_000002.parquet
+ │ │ ├── file-000.parquet
+ │ │ ├── file-001.parquet
│ │ └── ...
│ ├── chunk-001
- │ │ ├── episode_001000.parquet
- │ │ ├── episode_001001.parquet
- │ │ ├── episode_001002.parquet
+ │ │ ├── file-000.parquet
+ │ │ ├── file-001.parquet
│ │ └── ...
│ └── ...
├── meta
- │ ├── episodes.jsonl
+ │ ├── episodes
+ │ │ ├── chunk-000
+ │ │ │ ├── file-000.parquet
+ │ │ │ ├── file-001.parquet
+ │ │ │ └── ...
+ │ │ ├── chunk-001
+ │ │ │ └── ...
+ │ │ └── ...
│ ├── info.json
│ ├── stats.json
- │ └── tasks.jsonl
+ │ └── tasks.parquet
└── videos
- ├── chunk-000
- │ ├── observation.images.laptop
- │ │ ├── episode_000000.mp4
- │ │ ├── episode_000001.mp4
- │ │ ├── episode_000002.mp4
+ ├── observation.images.laptop
+ │ ├── chunk-000
+ │ │ ├── file-000.mp4
+ │ │ ├── file-001.mp4
│ │ └── ...
- │ ├── observation.images.phone
- │ │ ├── episode_000000.mp4
- │ │ ├── episode_000001.mp4
- │ │ ├── episode_000002.mp4
+ │ ├── chunk-001
│ │ └── ...
- ├── chunk-001
+ │ └── ...
+ ├── observation.images.phone
+ │ ├── chunk-000
+ │ │ ├── file-000.mp4
+ │ │ ├── file-001.mp4
+ │ │ └── ...
+ │ ├── chunk-001
+ │ │ └── ...
+ │ └── ...
└── ...
- Note that this file-based structure is designed to be as versatile as possible. The files are split by
- episodes which allows a more granular control over which episodes one wants to use and download. The
+ Note that this file-based structure is designed to be as versatile as possible. Multiple episodes are
+ consolidated into chunked files which improves storage efficiency and loading performance. The
structure of the dataset is entirely described in the info.json file, which can be easily downloaded
or viewed directly on the hub before downloading any actual data. The type of files used are very
simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md
@@ -434,7 +575,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
multiples of 1/fps. Defaults to 1e-4.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. Defaults to current codebase version tag.
- sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
+ force_cache_sync (bool, optional): Flag to sync and refresh local files first. If True and files
are already present in the local cache, this will be faster. However, files loaded might not
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
False.
@@ -443,6 +584,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+ batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
+ Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
"""
super().__init__()
self.repo_id = repo_id
@@ -454,6 +597,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
+ self.batch_encoding_size = batch_encoding_size
+ self.episodes_since_last_encoding = 0
# Unused attributes
self.image_writer = None
@@ -465,29 +610,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
- if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
- episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
- self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
- assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
+ # Check if cached dataset contains all requested episodes
+ if not self._check_cached_episodes_sufficient():
+ raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
- self.download_episodes(download_videos)
+ self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
- self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
-
- # Check timestamps
- timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
- episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
- ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
- check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
-
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -538,11 +674,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
- if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
- card = create_lerobot_dataset_card(
- tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
- )
- card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+ card = create_lerobot_dataset_card(
+ tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
+ )
+ card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
@@ -563,7 +698,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ignore_patterns=ignore_patterns,
)
- def download_episodes(self, download_videos: bool = True) -> None:
+ def download(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -571,11 +706,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
- files = None
ignore_patterns = None if download_videos else "videos/"
+ files = None
if self.episodes is not None:
files = self.get_episodes_file_paths()
-
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
@@ -588,28 +722,43 @@ class LeRobotDataset(torch.utils.data.Dataset):
for ep_idx in episodes
]
fpaths += video_files
-
+ # episodes are stored in the same files, so we return unique paths only
+ fpaths = list(set(fpaths))
return fpaths
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
- if self.episodes is None:
- path = str(self.root / "data")
- hf_dataset = load_dataset("parquet", data_dir=path, split="train")
- else:
- files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
- hf_dataset = load_dataset("parquet", data_files=files, split="train")
-
- # TODO(aliberts): hf_dataset.set_format("torch")
+ features = get_hf_features_from_features(self.features)
+ hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
+ def _check_cached_episodes_sufficient(self) -> bool:
+ """Check if the cached dataset contains all requested episodes."""
+ if self.hf_dataset is None or len(self.hf_dataset) == 0:
+ return False
+
+ # Get available episode indices from cached dataset
+ available_episodes = {
+ ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
+ for ep_idx in self.hf_dataset["episode_index"]
+ }
+
+ # Determine requested episodes
+ if self.episodes is None:
+ # Requesting all episodes - check if we have all episodes from metadata
+ requested_episodes = set(range(self.meta.total_episodes))
+ else:
+ # Requesting specific episodes
+ requested_episodes = set(self.episodes)
+
+ # Check if all requested episodes are available in cached data
+ return requested_episodes.issubset(available_episodes)
+
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
-
- # TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -641,15 +790,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
- ep_start = self.episode_data_index["from"][ep_idx]
- ep_end = self.episode_data_index["to"][ep_idx]
+ ep = self.meta.episodes[ep_idx]
+ ep_start = ep["dataset_from_index"]
+ ep_end = ep["dataset_to_index"]
query_indices = {
- key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
+ key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
- [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
+ [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
)
for key, delta_idx in self.delta_indices.items()
}
@@ -663,7 +813,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
- timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
+ timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -672,7 +822,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
- key: torch.stack(self.hf_dataset.select(q_idx)[key])
+ key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
@@ -683,19 +833,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
the main process and a subprocess fails to access it.
"""
+ ep = self.meta.episodes[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items():
+ # Episodes are stored sequentially on a single mp4 to reduce the number of files.
+ # Thus we load the start timestamp of the episode on this mp4 and,
+ # shift the query timestamp accordingly.
+ from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
+ shifted_query_ts = [from_timestamp + ts for ts in query_ts]
+
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
- frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
+ frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend)
item[vid_key] = frames.squeeze(0)
return item
- def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
- for key, val in padding.items():
- item[key] = torch.BoolTensor(val)
- return item
-
def __len__(self):
return self.num_frames
@@ -724,8 +876,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
- item["task"] = self.meta.tasks[task_idx]
-
+ item["task"] = self.meta.tasks.iloc[task_idx].name
return item
def __repr__(self):
@@ -755,6 +906,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
return self.root / fpath
+ def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
+ return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
+
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
@@ -763,7 +917,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
self.image_writer.save_image(image=image, fpath=fpath)
- def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
+ def add_frame(self, frame: dict) -> None:
"""
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
@@ -781,11 +935,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
- if timestamp is None:
- timestamp = frame_index / self.fps
+ timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
- self.episode_buffer["task"].append(task)
+ self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
# Add frame features to episode_buffer
for key in frame:
@@ -811,13 +964,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
This will save to disk the current episode in self.episode_buffer.
+ Video encoding is handled automatically based on batch_encoding_size:
+ - If batch_encoding_size == 1: Videos are encoded immediately after each episode
+ - If batch_encoding_size > 1: Videos are encoded in batches.
+
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
- if not episode_data:
- episode_buffer = self.episode_buffer
+ episode_buffer = episode_data if episode_data is not None else self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
@@ -830,11 +986,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
- # Add new tasks to the tasks dictionary
- for task in episode_tasks:
- task_index = self.meta.get_task_index(task)
- if task_index is None:
- self.meta.add_task(task)
+ # Update tasks and task indices with new tasks if any
+ self.meta.save_episode_tasks(episode_tasks)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
@@ -846,59 +999,234 @@ class LeRobotDataset(torch.utils.data.Dataset):
continue
episode_buffer[key] = np.stack(episode_buffer[key])
+ # Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer()
- self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
- if len(self.meta.video_keys) > 0:
- video_paths = self.encode_episode_videos(episode_index)
- for key in self.meta.video_keys:
- episode_buffer[key] = video_paths[key]
+ ep_metadata = self._save_episode_data(episode_buffer)
+ has_video_keys = len(self.meta.video_keys) > 0
+ use_batched_encoding = self.batch_encoding_size > 1
- # `meta.save_episode` be executed after encoding the videos
- self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
+ if has_video_keys and not use_batched_encoding:
+ for video_key in self.meta.video_keys:
+ ep_metadata.update(self._save_episode_video(video_key, episode_index))
- ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
- ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
- check_timestamps_sync(
- episode_buffer["timestamp"],
- episode_buffer["episode_index"],
- ep_data_index_np,
- self.fps,
- self.tolerance_s,
+ # `meta.save_episode` need to be executed after encoding the videos
+ self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
+
+ if has_video_keys and use_batched_encoding:
+ # Check if we should trigger batch encoding
+ self.episodes_since_last_encoding += 1
+ if self.episodes_since_last_encoding == self.batch_encoding_size:
+ start_ep = self.num_episodes - self.batch_encoding_size
+ end_ep = self.num_episodes
+ self._batch_save_episode_video(start_ep, end_ep)
+ self.episodes_since_last_encoding = 0
+
+ if not episode_data:
+ # Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
+ self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
+
+ def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
+ """
+ Batch save videos for multiple episodes.
+
+ Args:
+ start_episode: Starting episode index (inclusive)
+ end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
+ """
+ if end_episode is None:
+ end_episode = self.num_episodes
+
+ logging.info(
+ f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
- video_files = list(self.root.rglob("*.mp4"))
- assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
+ chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
+ file_idx = self.meta.episodes[start_episode]["data/file_index"]
+ episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ episode_df = pd.read_parquet(episode_df_path)
- parquet_files = list(self.root.rglob("*.parquet"))
- assert len(parquet_files) == self.num_episodes
+ for ep_idx in range(start_episode, end_episode):
+ logging.info(f"Encoding videos for episode {ep_idx}")
- # delete images
- img_dir = self.root / "images"
- if img_dir.is_dir():
- shutil.rmtree(self.root / "images")
+ if (
+ self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
+ or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
+ ):
+ # The current episode is in a new chunk or file.
+ # Save previous episode dataframe and update the Hugging Face dataset by reloading it.
+ episode_df.to_parquet(episode_df_path)
+ self.meta.episodes = load_episodes(self.root)
- if not episode_data: # Reset the buffer
- self.episode_buffer = self.create_episode_buffer()
+ # Load new episode dataframe
+ chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
+ file_idx = self.meta.episodes[ep_idx]["data/file_index"]
+ episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
+ chunk_index=chunk_idx, file_index=file_idx
+ )
+ episode_df = pd.read_parquet(episode_df_path)
- def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
- episode_dict = {key: episode_buffer[key] for key in self.hf_features}
- ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
+ # Save the current episode's video metadata to the dataframe
+ video_ep_metadata = {}
+ for video_key in self.meta.video_keys:
+ video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
+ video_ep_metadata.pop("episode_index")
+ video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
+ dtype_backend="pyarrow"
+ ) # allows NaN values along with integers
+
+ episode_df = episode_df.combine_first(video_ep_df)
+ episode_df.to_parquet(episode_df_path)
+ self.meta.episodes = load_episodes(self.root)
+
+ def _save_episode_data(self, episode_buffer: dict) -> dict:
+ """Save episode data to a parquet file and update the Hugging Face dataset of frames data.
+
+ This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
+ and saves it as a parquet file. It handles both the creation of new parquet files and the
+ updating of existing ones based on size constraints. After saving the data, it reloads
+ the Hugging Face dataset to ensure it is up-to-date.
+
+ Notes: We both need to update parquet files and HF dataset:
+ - `pandas` loads parquet file in RAM
+ - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
+ or loads directly from pyarrow cache.
+ """
+ # Convert buffer into HF Dataset
+ ep_dict = {key: episode_buffer[key] for key in self.hf_features}
+ ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
- self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
- self.hf_dataset.set_transform(hf_transform_to_torch)
- ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
- ep_data_path.parent.mkdir(parents=True, exist_ok=True)
- ep_dataset.to_parquet(ep_data_path)
+ ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
+ ep_num_frames = len(ep_dataset)
+ df = pd.DataFrame(ep_dataset)
- def clear_episode_buffer(self) -> None:
- episode_index = self.episode_buffer["episode_index"]
- if self.image_writer is not None:
+ if self.meta.episodes is None:
+ # Initialize indices and frame count for a new dataset made of the first episode data
+ chunk_idx, file_idx = 0, 0
+ latest_num_frames = 0
+ else:
+ # Retrieve information from the latest parquet file
+ latest_ep = self.meta.episodes[-1]
+ chunk_idx = latest_ep["data/chunk_index"]
+ file_idx = latest_ep["data/file_index"]
+
+ latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
+ latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
+ latest_num_frames = get_parquet_num_frames(latest_path)
+
+ # Determine if a new parquet file is needed
+ if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
+ # Size limit is reached, prepare new parquet file
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+ latest_num_frames = 0
+ else:
+ # Update the existing parquet file with new rows
+ latest_df = pd.read_parquet(latest_path)
+ df = pd.concat([latest_df, df], ignore_index=True)
+
+ # Memort optimization
+ del latest_df
+ gc.collect()
+
+ # Write the resulting dataframe from RAM to disk
+ path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ if len(self.meta.image_keys) > 0:
+ to_parquet_with_hf_images(df, path)
+ else:
+ df.to_parquet(path)
+
+ if self.hf_dataset is not None:
+ # Remove hf dataset cache directory, necessary to avoid cache bloat
+ cached_dir = get_hf_dataset_cache_dir(self.hf_dataset)
+ if cached_dir is not None:
+ shutil.rmtree(cached_dir)
+
+ self.hf_dataset = self.load_hf_dataset()
+
+ metadata = {
+ "data/chunk_index": chunk_idx,
+ "data/file_index": file_idx,
+ "dataset_from_index": latest_num_frames,
+ "dataset_to_index": latest_num_frames + ep_num_frames,
+ }
+ return metadata
+
+ def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
+ # Encode episode frames into a temporary video
+ ep_path = self._encode_temporary_episode_video(video_key, episode_index)
+ ep_size_in_mb = get_video_size_in_mb(ep_path)
+ ep_duration_in_s = get_video_duration_in_s(ep_path)
+
+ if self.meta.episodes is None or (
+ f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names
+ or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names
+ ):
+ # Initialize indices for a new dataset made of the first episode data
+ chunk_idx, file_idx = 0, 0
+ latest_duration_in_s = 0.0
+ new_path = self.root / self.meta.video_path.format(
+ video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+ )
+ new_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.move(str(ep_path), str(new_path))
+ else:
+ # Retrieve information from the latest updated video file (possibly several episodes ago)
+ latest_ep = self.meta.episodes[episode_index - 1]
+ chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
+ file_idx = latest_ep[f"videos/{video_key}/file_index"]
+
+ latest_path = self.root / self.meta.video_path.format(
+ video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+ )
+ latest_size_in_mb = get_video_size_in_mb(latest_path)
+ latest_duration_in_s = get_video_duration_in_s(latest_path)
+
+ if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
+ # Move temporary episode video to a new video file in the dataset
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+ new_path = self.root / self.meta.video_path.format(
+ video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+ )
+ new_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.move(str(ep_path), str(new_path))
+ latest_duration_in_s = 0.0
+ else:
+ # Update latest video file
+ concatenate_video_files(
+ [latest_path, ep_path],
+ latest_path,
+ )
+
+ # Remove temporary directory
+ shutil.rmtree(str(ep_path.parent))
+
+ # Update video info (only needed when first episode is encoded since it reads from episode 0)
+ if episode_index == 0:
+ self.meta.update_video_info(video_key)
+ write_info(self.meta.info, self.meta.root) # ensure video info always written properly
+
+ metadata = {
+ "episode_index": episode_index,
+ f"videos/{video_key}/chunk_index": chunk_idx,
+ f"videos/{video_key}/file_index": file_idx,
+ f"videos/{video_key}/from_timestamp": latest_duration_in_s,
+ f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
+ }
+ return metadata
+
+ def clear_episode_buffer(self, delete_images: bool = True) -> None:
+ # Clean up image files for the current episode buffer
+ if delete_images:
+ # Wait for the async image writer to finish
+ if self.image_writer is not None:
+ self._wait_image_writer()
+ episode_index = self.episode_buffer["episode_index"]
+ if isinstance(episode_index, np.ndarray):
+ episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
for cam_key in self.meta.camera_keys:
- img_dir = self._get_image_file_path(
- episode_index=episode_index, image_key=cam_key, frame_index=0
- ).parent
+ img_dir = self._get_image_file_dir(episode_index, cam_key)
if img_dir.is_dir():
shutil.rmtree(img_dir)
@@ -919,7 +1247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def stop_image_writer(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
- remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
+ remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()
@@ -930,34 +1258,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
- def encode_videos(self) -> None:
+ def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
- for ep_idx in range(self.meta.total_episodes):
- self.encode_episode_videos(ep_idx)
-
- def encode_episode_videos(self, episode_index: int) -> dict:
- """
- Use ffmpeg to convert frames stored as png into mp4 videos.
- Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
- since video encoding with ffmpeg is already using multithreading.
- """
- video_paths = {}
- for key in self.meta.video_keys:
- video_path = self.root / self.meta.get_video_file_path(episode_index, key)
- video_paths[key] = str(video_path)
- if video_path.is_file():
- # Skip if video is already encoded. Could be the case when resuming data recording.
- continue
- img_dir = self._get_image_file_path(
- episode_index=episode_index, image_key=key, frame_index=0
- ).parent
- encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
-
- return video_paths
+ temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
+ img_dir = self._get_image_file_dir(episode_index, video_key)
+ encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
+ shutil.rmtree(img_dir)
+ return temp_path
@classmethod
def create(
@@ -972,6 +1283,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,
+ batch_encoding_size: int = 1,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
@@ -988,6 +1300,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
+ obj.batch_encoding_size = batch_encoding_size
+ obj.episodes_since_last_encoding = 0
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
@@ -1000,7 +1314,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
- obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj
@@ -1018,7 +1331,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
- delta_timestamps: dict[list[float]] | None = None,
+ delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
@@ -1078,11 +1391,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
- @property
- def repo_index_to_id(self):
- """Return the inverse mapping if repo_id_to_index."""
- return {v: k for k, v in self.repo_id_to_index}
-
@property
def fps(self) -> int:
"""Frames per second used during data collection.
@@ -1113,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
- if isinstance(feats, (datasets.Image, VideoFrame)):
+ if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
diff --git a/src/lerobot/datasets/online_buffer.py b/src/lerobot/datasets/online_buffer.py
index 79f48f49d..563d800b9 100644
--- a/src/lerobot/datasets/online_buffer.py
+++ b/src/lerobot/datasets/online_buffer.py
@@ -337,13 +337,11 @@ def compute_sampler_weights(
if len(offline_dataset) > 0:
offline_data_mask_indices = []
for start_index, end_index in zip(
- offline_dataset.episode_data_index["from"],
- offline_dataset.episode_data_index["to"],
+ offline_dataset.meta.episodes["dataset_from_index"],
+ offline_dataset.meta.episodes["dataset_to_index"],
strict=True,
):
- offline_data_mask_indices.extend(
- range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
- )
+ offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
weights.append(
diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py
new file mode 100644
index 000000000..4fad7bd20
--- /dev/null
+++ b/src/lerobot/datasets/pipeline_features.py
@@ -0,0 +1,139 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from collections.abc import Sequence
+from typing import Any
+
+from lerobot.configs.types import PipelineFeatureType
+from lerobot.datasets.utils import hw_to_dataset_features
+from lerobot.processor import DataProcessorPipeline
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
+
+
+def create_initial_features(
+ action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None
+) -> dict[PipelineFeatureType, dict[str, Any]]:
+ """
+ Creates the initial features dict for the dataset from action and observation specs.
+
+ Args:
+ action: A dictionary of action feature names to their types/shapes.
+ observation: A dictionary of observation feature names to their types/shapes.
+
+ Returns:
+ The initial features dictionary structured by PipelineFeatureType.
+ """
+ features = {PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}}
+ if action:
+ features[PipelineFeatureType.ACTION] = action
+ if observation:
+ features[PipelineFeatureType.OBSERVATION] = observation
+ return features
+
+
+# Helper to filter state/action keys based on regex patterns.
+def should_keep(key: str, patterns: tuple[str]) -> bool:
+ if patterns is None:
+ return True
+ return any(re.search(pat, key) for pat in patterns)
+
+
+def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
+ for prefix in prefixes_to_strip:
+ if key.startswith(prefix):
+ return key[len(prefix) :]
+ return key
+
+
+# Define prefixes to strip from feature keys for clean names.
+# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms.
+PREFIXES_TO_STRIP = tuple(
+ f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1])
+)
+
+
+def aggregate_pipeline_dataset_features(
+ pipeline: DataProcessorPipeline,
+ initial_features: dict[PipelineFeatureType, dict[str, Any]],
+ *,
+ use_videos: bool = True,
+ patterns: Sequence[str] | None = None,
+) -> dict[str, dict]:
+ """
+ Aggregates and filters pipeline features to create a dataset-ready features dictionary.
+
+ This function transforms initial features using the pipeline, categorizes them as action or observations
+ (image or state), filters them based on `use_videos` and `patterns`, and finally
+ formats them for use with a Hugging Face LeRobot Dataset.
+
+ Args:
+ pipeline: The DataProcessorPipeline to apply.
+ initial_features: A dictionary of raw feature specs for actions and observations.
+ use_videos: If False, image features are excluded.
+ patterns: A sequence of regex patterns to filter action and state features.
+ Image features are not affected by this filter.
+
+ Returns:
+ A dictionary of features formatted for a Hugging Face LeRobot Dataset.
+ """
+ all_features = pipeline.transform_features(initial_features)
+
+ # Intermediate storage for categorized and filtered features.
+ processed_features: dict[str, dict[str, Any]] = {
+ ACTION: {},
+ OBS_STR: {},
+ }
+ images_token = OBS_IMAGES.split(".")[-1]
+
+ # Iterate through all features transformed by the pipeline.
+ for ptype, feats in all_features.items():
+ if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]:
+ continue
+
+ for key, value in feats.items():
+ # 1. Categorize the feature.
+ is_action = ptype == PipelineFeatureType.ACTION
+ # Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3.
+ # All other observations are treated as state.
+ is_image = not is_action and (
+ (isinstance(value, tuple) and len(value) == 3)
+ or (
+ key.startswith(f"{OBS_IMAGES}.")
+ or key.startswith(f"{images_token}.")
+ or f".{images_token}." in key
+ )
+ )
+
+ # 2. Apply filtering rules.
+ if is_image and not use_videos:
+ continue
+ if not is_image and not should_keep(key, patterns):
+ continue
+
+ # 3. Add the feature to the appropriate group with a clean name.
+ name = strip_prefix(key, PREFIXES_TO_STRIP)
+ if is_action:
+ processed_features[ACTION][name] = value
+ else:
+ processed_features[OBS_STR][name] = value
+
+ # Convert the processed features into the final dataset format.
+ dataset_features = {}
+ if processed_features[ACTION]:
+ dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
+ if processed_features[OBS_STR]:
+ dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
+
+ return dataset_features
diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py
index 6aca7b03b..48214e1bf 100644
--- a/src/lerobot/datasets/push_dataset_to_hub/utils.py
+++ b/src/lerobot/datasets/push_dataset_to_hub/utils.py
@@ -13,71 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
-from concurrent.futures import ThreadPoolExecutor
-from pathlib import Path
-from typing import Dict
import datasets
-import numpy
-import PIL
import torch
-from lerobot.datasets.video_utils import encode_video_frames
-
-
-def concatenate_episodes(ep_dicts):
- data_dict = {}
-
- keys = ep_dicts[0].keys()
- for key in keys:
- if torch.is_tensor(ep_dicts[0][key][0]):
- data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
- else:
- if key not in data_dict:
- data_dict[key] = []
- for ep_dict in ep_dicts:
- for x in ep_dict[key]:
- data_dict[key].append(x)
-
- total_frames = data_dict["frame_index"].shape[0]
- data_dict["index"] = torch.arange(0, total_frames, 1)
- return data_dict
-
-
-def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
- out_dir = Path(out_dir)
- out_dir.mkdir(parents=True, exist_ok=True)
-
- def save_image(img_array, i, out_dir):
- img = PIL.Image.fromarray(img_array)
- img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
-
- num_images = len(imgs_array)
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
-
-
-def get_default_encoding() -> dict:
- """Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
- signature = inspect.signature(encode_video_frames)
- return {
- k: v.default
- for k, v in signature.parameters.items()
- if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
- }
-
-
-def check_repo_id(repo_id: str) -> None:
- if len(repo_id.split("/")) != 2:
- raise ValueError(
- f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
- (e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
- )
-
# TODO(aliberts): remove
-def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
+def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py
index 2f6c15c15..d0bb20c27 100644
--- a/src/lerobot/datasets/sampler.py
+++ b/src/lerobot/datasets/sampler.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterator, Union
+from collections.abc import Iterator
import torch
@@ -21,8 +21,9 @@ import torch
class EpisodeAwareSampler:
def __init__(
self,
- episode_data_index: dict,
- episode_indices_to_use: Union[list, None] = None,
+ dataset_from_indices: list[int],
+ dataset_to_indices: list[int],
+ episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
@@ -30,7 +31,8 @@ class EpisodeAwareSampler:
"""Sampler that optionally incorporates episode boundary information.
Args:
- episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
+ dataset_from_indices: List of indices containing the start of each episode in the dataset.
+ dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
@@ -39,12 +41,10 @@ class EpisodeAwareSampler:
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
- zip(episode_data_index["from"], episode_data_index["to"], strict=True)
+ zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
- indices.extend(
- range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
- )
+ indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
self.indices = indices
self.shuffle = shuffle
diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py
new file mode 100644
index 000000000..454389d46
--- /dev/null
+++ b/src/lerobot/datasets/streaming_dataset.py
@@ -0,0 +1,533 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Callable, Generator, Iterator
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+from datasets import load_dataset
+
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
+from lerobot.datasets.utils import (
+ Backtrackable,
+ LookAheadError,
+ LookBackError,
+ check_version_compatibility,
+ find_float_index,
+ get_delta_indices,
+ is_float_in_list,
+ item_to_torch,
+ safe_shard,
+)
+from lerobot.datasets.video_utils import (
+ VideoDecoderCache,
+ decode_video_frames_torchcodec,
+)
+from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
+
+
+class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
+ """LeRobotDataset with streaming capabilities.
+
+ This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
+ rather than loaded entirely into memory. This is especially useful for large datasets that may
+ not fit in memory or when you want to quickly explore a dataset without downloading it completely.
+
+ The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
+ items, allowing us to access previous frames for delta timestamps without loading the entire
+ dataset into memory.
+
+ Example:
+ Basic usage:
+ ```python
+ from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
+
+ # Create a streaming dataset with delta timestamps
+ delta_timestamps = {
+ "observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
+ "action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
+ }
+
+ dataset = StreamingLeRobotDataset(
+ repo_id="your-dataset-repo-id",
+ delta_timestamps=delta_timestamps,
+ streaming=True,
+ buffer_size=1000,
+ )
+
+ # Iterate over the dataset
+ for i, item in enumerate(dataset):
+ print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
+ # item will contain stacked frames according to delta_timestamps
+ if i >= 10:
+ break
+ ```
+ """
+
+ def __init__(
+ self,
+ repo_id: str,
+ root: str | Path | None = None,
+ episodes: list[int] | None = None,
+ image_transforms: Callable | None = None,
+ delta_timestamps: dict[list[float]] | None = None,
+ tolerance_s: float = 1e-4,
+ revision: str | None = None,
+ force_cache_sync: bool = False,
+ streaming: bool = True,
+ buffer_size: int = 1000,
+ max_num_shards: int = 16,
+ seed: int = 42,
+ rng: np.random.Generator | None = None,
+ shuffle: bool = True,
+ ):
+ """Initialize a StreamingLeRobotDataset.
+
+ Args:
+ repo_id (str): This is the repo id that will be used to fetch the dataset.
+ root (Path | None, optional): Local directory to use for downloading/writing files.
+ episodes (list[int] | None, optional): If specified, this will only load episodes specified by
+ their episode_index in this list.
+ image_transforms (Callable | None, optional): Transform to apply to image data.
+ tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
+ revision (str, optional): Git revision id (branch name, tag, or commit hash).
+ force_cache_sync (bool, optional): Flag to sync and refresh local files first.
+ streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
+ buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
+ max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
+ seed (int, optional): Reproducibility random seed.
+ rng (np.random.Generator | None, optional): Random number generator.
+ shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
+ """
+ super().__init__()
+ self.repo_id = repo_id
+ self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
+ self.streaming_from_local = root is not None
+
+ self.image_transforms = image_transforms
+ self.episodes = episodes
+ self.tolerance_s = tolerance_s
+ self.revision = revision if revision else CODEBASE_VERSION
+ self.seed = seed
+ self.rng = rng if rng is not None else np.random.default_rng(seed)
+ self.shuffle = shuffle
+
+ self.streaming = streaming
+ self.buffer_size = buffer_size
+
+ # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
+ self.video_decoder_cache = None
+
+ self.root.mkdir(exist_ok=True, parents=True)
+
+ # Load metadata
+ self.meta = LeRobotDatasetMetadata(
+ self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
+ )
+ # Check version
+ check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
+
+ self.delta_timestamps = None
+ self.delta_indices = None
+
+ if delta_timestamps is not None:
+ self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
+ self.delta_timestamps = delta_timestamps
+ self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
+
+ self.hf_dataset: datasets.IterableDataset = load_dataset(
+ self.repo_id if not self.streaming_from_local else str(self.root),
+ split="train",
+ streaming=self.streaming,
+ data_files="data/*/*.parquet",
+ revision=self.revision,
+ )
+
+ self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
+
+ @property
+ def num_frames(self):
+ return self.meta.total_frames
+
+ @property
+ def num_episodes(self):
+ return self.meta.total_episodes
+
+ @property
+ def fps(self):
+ return self.meta.fps
+
+ @staticmethod
+ def _iter_random_indices(
+ rng: np.random.Generator, buffer_size: int, random_batch_size=100
+ ) -> Iterator[int]:
+ while True:
+ yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
+
+ @staticmethod
+ def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
+ while True:
+ yield rng.choice(elements)
+
+ # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
+ # The current sequential iteration is a bottleneck. A producer-consumer pattern
+ # could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
+ # in parallel, feeding a queue from which this iterator will yield processed items.
+ def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
+ if self.video_decoder_cache is None:
+ self.video_decoder_cache = VideoDecoderCache()
+
+ # keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
+ rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
+
+ buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
+
+ idx_to_backtrack_dataset = {
+ idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
+ for idx in range(self.num_shards)
+ }
+
+ # This buffer is populated while iterating on the dataset's shards
+ # the logic is to add 2 levels of randomness:
+ # (1) sample one shard at random from the ones available, and
+ # (2) sample one frame from the shard sampled at (1)
+ frames_buffer = []
+ while available_shards := list(idx_to_backtrack_dataset.keys()):
+ shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
+ backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
+
+ try:
+ for frame in self.make_frame(backtrack_dataset):
+ if len(frames_buffer) == self.buffer_size:
+ i = next(buffer_indices_generator) # samples a element from the buffer
+ yield frames_buffer[i]
+ frames_buffer[i] = frame
+ else:
+ frames_buffer.append(frame)
+ break # random shard sampled, switch shard
+ except (
+ RuntimeError,
+ StopIteration,
+ ): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
+ del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
+
+ # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
+ rng.shuffle(frames_buffer)
+ yield from frames_buffer
+
+ def _get_window_steps(
+ self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
+ ) -> tuple[int, int]:
+ if delta_timestamps is None:
+ return 1, 1
+
+ if not dynamic_bounds:
+ # Fix the windows
+ lookback = LOOKBACK_BACKTRACKTABLE
+ lookahead = LOOKAHEAD_BACKTRACKTABLE
+ else:
+ # Dynamically adjust the windows based on the given delta_timesteps
+ all_timestamps = sum(delta_timestamps.values(), [])
+ lookback = min(all_timestamps) * self.fps
+ lookahead = max(all_timestamps) * self.fps
+
+ # When lookback is >=0 it means no negative timesteps have been provided
+ lookback = 0 if lookback >= 0 else (lookback * -1)
+
+ return lookback, lookahead
+
+ def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
+ lookback, lookahead = self._get_window_steps(self.delta_timestamps)
+ return Backtrackable(dataset, history=lookback, lookahead=lookahead)
+
+ def _make_timestamps_from_indices(
+ self, start_ts: float, indices: dict[str, list[int]] | None = None
+ ) -> dict[str, list[float]]:
+ if indices is not None:
+ return {
+ key: (
+ start_ts + torch.tensor(indices[key]) / self.fps
+ ).tolist() # NOTE: why not delta_timestamps directly?
+ for key in self.delta_timestamps
+ }
+ else:
+ return dict.fromkeys(self.meta.video_keys, [start_ts])
+
+ def _make_padding_camera_frame(self, camera_key: str):
+ """Variable-shape padding frame for given camera keys, given in (H, W, C)"""
+ return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
+
+ def _get_video_frame_padding_mask(
+ self,
+ video_frames: dict[str, torch.Tensor],
+ query_timestamps: dict[str, list[float]],
+ original_timestamps: dict[str, list[float]],
+ ) -> dict[str, torch.BoolTensor]:
+ padding_mask = {}
+
+ for video_key, timestamps in original_timestamps.items():
+ if video_key not in video_frames:
+ continue # only padding on video keys that are available
+ frames = []
+ mask = []
+ padding_frame = self._make_padding_camera_frame(video_key)
+ for ts in timestamps:
+ if is_float_in_list(ts, query_timestamps[video_key]):
+ idx = find_float_index(ts, query_timestamps[video_key])
+ frames.append(video_frames[video_key][idx, :])
+ mask.append(False)
+ else:
+ frames.append(padding_frame)
+ mask.append(True)
+
+ padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
+
+ return padding_mask
+
+ def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
+ """Makes a frame starting from a dataset iterator"""
+ item = next(dataset_iterator)
+ item = item_to_torch(item)
+
+ updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
+
+ # Get episode index from the item
+ ep_idx = item["episode_index"]
+
+ # "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
+ current_ts = item["index"] / self.fps
+
+ episode_boundaries_ts = {
+ key: (
+ self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
+ self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
+ )
+ for key in self.meta.video_keys
+ }
+
+ # Apply delta querying logic if necessary
+ if self.delta_indices is not None:
+ query_result, padding = self._get_delta_frames(dataset_iterator, item)
+ updates.append(query_result)
+ updates.append(padding)
+
+ # Load video frames, when needed
+ if len(self.meta.video_keys) > 0:
+ original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
+
+ # Some timestamps might not result available considering the episode's boundaries
+ query_timestamps = self._get_query_timestamps(
+ current_ts, self.delta_indices, episode_boundaries_ts
+ )
+ video_frames = self._query_videos(query_timestamps, ep_idx)
+
+ if self.image_transforms is not None:
+ image_keys = self.meta.camera_keys
+ for cam in image_keys:
+ video_frames[cam] = self.image_transforms(video_frames[cam])
+
+ updates.append(video_frames)
+
+ if self.delta_indices is not None:
+ # We always return the same number of frames. Unavailable frames are padded.
+ padding_mask = self._get_video_frame_padding_mask(
+ video_frames, query_timestamps, original_timestamps
+ )
+ updates.append(padding_mask)
+
+ result = item.copy()
+ for update in updates:
+ result.update(update)
+
+ result["task"] = self.meta.tasks.iloc[item["task_index"]].name
+
+ yield result
+
+ def _get_query_timestamps(
+ self,
+ current_ts: float,
+ query_indices: dict[str, list[int]] | None = None,
+ episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
+ ) -> dict[str, list[float]]:
+ query_timestamps = {}
+ keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
+ for key in self.meta.video_keys:
+ if query_indices is not None and key in query_indices:
+ timestamps = keys_to_timestamps[key]
+ # Clamp out timesteps outside of episode boundaries
+ query_timestamps[key] = torch.clamp(
+ torch.tensor(timestamps), *episode_boundaries_ts[key]
+ ).tolist()
+
+ else:
+ query_timestamps[key] = [current_ts]
+
+ return query_timestamps
+
+ def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+ """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
+ in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
+ Segmentation Fault. This probably happens because a memory reference to the video loader is created in
+ the main process and a subprocess fails to access it.
+ """
+
+ item = {}
+ for video_key, query_ts in query_timestamps.items():
+ root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
+ video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
+ frames = decode_video_frames_torchcodec(
+ video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
+ )
+
+ item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+
+ return item
+
+ def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
+ # TODO(fracapuano): Modularize this function, refactor the code
+ """Get frames with delta offsets using the backtrackable iterator.
+
+ Args:
+ current_item (dict): Current item from the iterator.
+ ep_idx (int): Episode index.
+
+ Returns:
+ tuple: (query_result, padding) - frames at delta offsets and padding info.
+ """
+ current_episode_idx = current_item["episode_index"]
+
+ # Prepare results
+ query_result = {}
+ padding = {}
+
+ for key, delta_indices in self.delta_indices.items():
+ if key in self.meta.video_keys:
+ continue # visual frames are decoded separately
+
+ target_frames = []
+ is_pad = []
+
+ # Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
+ delta_results = {}
+
+ # Separate and sort deltas by difficulty (easier operations first)
+ negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
+ positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
+ zero_deltas = [d for d in delta_indices if d == 0]
+
+ # Process zero deltas (current frame)
+ for delta in zero_deltas:
+ delta_results[delta] = (
+ current_item[key],
+ False,
+ )
+
+ # Process negative deltas in order of increasing difficulty
+ lookback_failed = False
+
+ last_successful_frame = current_item[key]
+
+ for delta in negative_deltas:
+ if lookback_failed:
+ delta_results[delta] = (last_successful_frame, True)
+ continue
+
+ try:
+ steps_back = abs(delta)
+ if dataset_iterator.can_peek_back(steps_back):
+ past_item = dataset_iterator.peek_back(steps_back)
+ past_item = item_to_torch(past_item)
+
+ if past_item["episode_index"] == current_episode_idx:
+ delta_results[delta] = (past_item[key], False)
+ last_successful_frame = past_item[key]
+
+ else:
+ raise LookBackError("Retrieved frame is from different episode!")
+ else:
+ raise LookBackError("Cannot go back further than the history buffer!")
+
+ except LookBackError:
+ delta_results[delta] = (last_successful_frame, True)
+ lookback_failed = True # All subsequent negative deltas will also fail
+
+ # Process positive deltas in order of increasing difficulty
+ lookahead_failed = False
+ last_successful_frame = current_item[key]
+
+ for delta in positive_deltas:
+ if lookahead_failed:
+ delta_results[delta] = (last_successful_frame, True)
+ continue
+
+ try:
+ if dataset_iterator.can_peek_ahead(delta):
+ future_item = dataset_iterator.peek_ahead(delta)
+ future_item = item_to_torch(future_item)
+
+ if future_item["episode_index"] == current_episode_idx:
+ delta_results[delta] = (future_item[key], False)
+ last_successful_frame = future_item[key]
+
+ else:
+ raise LookAheadError("Retrieved frame is from different episode!")
+ else:
+ raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
+
+ except LookAheadError:
+ delta_results[delta] = (last_successful_frame, True)
+ lookahead_failed = True # All subsequent positive deltas will also fail
+
+ # Reconstruct original order for stacking
+ for delta in delta_indices:
+ frame, is_padded = delta_results[delta]
+
+ # add batch dimension for stacking
+ target_frames.append(frame) # frame.unsqueeze(0))
+ is_pad.append(is_padded)
+
+ # Stack frames and add to results
+ if target_frames:
+ query_result[key] = torch.stack(target_frames)
+ padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
+
+ return query_result, padding
+
+ def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
+ """
+ Validate that all keys in delta_timestamps correspond to actual features in the dataset.
+
+ Raises:
+ ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
+ """
+ if delta_timestamps is None:
+ return
+
+ # Get all available feature keys from the dataset metadata
+ available_features = set(self.meta.features.keys())
+
+ # Get all keys from delta_timestamps
+ delta_keys = set(delta_timestamps.keys())
+
+ # Find any keys that don't correspond to features
+ invalid_keys = delta_keys - available_features
+
+ if invalid_keys:
+ raise ValueError(
+ f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
+ f"Available features are: {sorted(available_features)}"
+ )
diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py
index 3ac1d5771..f7072c72f 100644
--- a/src/lerobot/datasets/transforms.py
+++ b/src/lerobot/datasets/transforms.py
@@ -14,13 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
+from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
-from typing import Any, Callable, Sequence
+from typing import Any
import torch
from torchvision.transforms import v2
-from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2 import functional as F # noqa: N812
+from torchvision.transforms.v2 import (
+ Transform,
+ functional as F, # noqa: N812
+)
class RandomSubsetApply(Transform):
@@ -117,7 +120,7 @@ class SharpnessJitter(Transform):
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
- if isinstance(sharpness, (int, float)):
+ if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index ac0ab9799..a2f285014 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -17,54 +17,56 @@ import contextlib
import importlib.resources
import json
import logging
-from collections.abc import Iterator
-from itertools import accumulate
+from collections import deque
+from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
-from types import SimpleNamespace
-from typing import Any
+from typing import Any, Generic, TypeVar
import datasets
-import jsonlines
import numpy as np
import packaging.version
+import pandas
+import pandas as pd
+import pyarrow.parquet as pq
import torch
+from datasets import Dataset, concatenate_datasets
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
-from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
+from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
- V21_MESSAGE,
+ FUTURE_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
-from lerobot.robots import Robot
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
-DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
+DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
+DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
+DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
INFO_PATH = "meta/info.json"
-EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
-EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
-TASKS_PATH = "meta/tasks.jsonl"
-DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+EPISODES_DIR = "meta/episodes"
+DATA_DIR = "data"
+VIDEO_DIR = "videos"
-DATASET_CARD_TEMPLATE = """
----
-# Metadata will go there
----
-This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
+CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
+DEFAULT_TASKS_PATH = "meta/tasks.parquet"
+DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
+DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
+DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
+DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
-## {}
-
-"""
+LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
+LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
+LEGACY_TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
@@ -74,15 +76,83 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
+T = TypeVar("T")
+
+
+def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
+ metadata = pq.read_metadata(parquet_path)
+ total_uncompressed_size = 0
+ for row_group in range(metadata.num_row_groups):
+ rg_metadata = metadata.row_group(row_group)
+ for column in range(rg_metadata.num_columns):
+ col_metadata = rg_metadata.column(column)
+ total_uncompressed_size += col_metadata.total_uncompressed_size
+ return total_uncompressed_size / (1024**2)
+
+
+def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
+ return hf_ds.data.nbytes // (1024**2)
+
+
+def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None:
+ if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0:
+ return None
+ return Path(hf_ds.cache_files[0]["filename"]).parents[2]
+
+
+def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
+ if file_idx == chunks_size - 1:
+ file_idx = 0
+ chunk_idx += 1
+ else:
+ file_idx += 1
+ return chunk_idx, file_idx
+
+
+def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
+ """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
+ Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
+ Concatenate all pyarrow references to return HF Dataset format
+
+ Args:
+ pq_dir: Directory containing parquet files
+ features: Optional features schema to ensure consistent loading of complex types like images
+ """
+ paths = sorted(pq_dir.glob("*/*.parquet"))
+ if len(paths) == 0:
+ raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
+
+ # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
+ datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
+ return concatenate_datasets(datasets)
+
+
+def get_parquet_num_frames(parquet_path: str | Path) -> int:
+ metadata = pq.read_metadata(parquet_path)
+ return metadata.num_rows
+
+
+def get_video_size_in_mb(mp4_path: Path) -> float:
+ file_size_bytes = mp4_path.stat().st_size
+ file_size_mb = file_size_bytes / (1024**2)
+ return file_size_mb
+
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
- """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
+ """Flatten a nested dictionary by joining keys with a separator.
- For example:
- ```
- >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
- >>> print(flatten_dict(dct))
- {"a/b": 1, "a/c/d": 2, "e": 3}
+ Example:
+ >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
+ >>> print(flatten_dict(dct))
+ {'a/b': 1, 'a/c/d': 2, 'e': 3}
+
+ Args:
+ d (dict): The dictionary to flatten.
+ parent_key (str): The base key to prepend to the keys in this level.
+ sep (str): The separator to use between keys.
+
+ Returns:
+ dict: A flattened dictionary.
"""
items = []
for k, v in d.items():
@@ -95,6 +165,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
def unflatten_dict(d: dict, sep: str = "/") -> dict:
+ """Unflatten a dictionary with delimited keys into a nested dictionary.
+
+ Example:
+ >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
+ >>> print(unflatten_dict(flat_dct))
+ {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
+
+ Args:
+ d (dict): A dictionary with flattened keys.
+ sep (str): The separator used in the keys.
+
+ Returns:
+ dict: A nested dictionary.
+ """
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@@ -107,26 +191,29 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
-def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
- split_keys = flattened_key.split(sep)
- getter = obj[split_keys[0]]
- if len(split_keys) == 1:
- return getter
-
- for key in split_keys[1:]:
- getter = getter[key]
-
- return getter
-
-
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
+ """Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
+
+ Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types.
+
+ Args:
+ stats (dict): A dictionary that may contain non-serializable numeric types.
+
+ Returns:
+ dict: A dictionary with all values converted to JSON-serializable types.
+
+ Raises:
+ NotImplementedError: If a value has an unsupported type.
+ """
serialized_dict = {}
for key, value in flatten_dict(stats).items():
- if isinstance(value, (torch.Tensor, np.ndarray)):
+ if isinstance(value, (torch.Tensor | np.ndarray)):
serialized_dict[key] = value.tolist()
+ elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
+ serialized_dict[key] = value
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
- elif isinstance(value, (int, float)):
+ elif isinstance(value, (int | float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -134,6 +221,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
+ """Embed image bytes into the dataset table before saving to Parquet.
+
+ This function prepares a Hugging Face dataset for serialization by converting
+ image objects into an embedded format that can be stored in Arrow/Parquet.
+
+ Args:
+ dataset (datasets.Dataset): The input dataset, possibly containing image features.
+
+ Returns:
+ datasets.Dataset: The dataset with images embedded in the table storage.
+ """
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
@@ -143,109 +241,151 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
def load_json(fpath: Path) -> Any:
+ """Load data from a JSON file.
+
+ Args:
+ fpath (Path): Path to the JSON file.
+
+ Returns:
+ Any: The data loaded from the JSON file.
+ """
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
+ """Write data to a JSON file.
+
+ Creates parent directories if they don't exist.
+
+ Args:
+ data (dict): The dictionary to write.
+ fpath (Path): The path to the output JSON file.
+ """
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
-def load_jsonlines(fpath: Path) -> list[Any]:
- with jsonlines.open(fpath, "r") as reader:
- return list(reader)
-
-
-def write_jsonlines(data: dict, fpath: Path) -> None:
- fpath.parent.mkdir(exist_ok=True, parents=True)
- with jsonlines.open(fpath, "w") as writer:
- writer.write_all(data)
-
-
-def append_jsonlines(data: dict, fpath: Path) -> None:
- fpath.parent.mkdir(exist_ok=True, parents=True)
- with jsonlines.open(fpath, "a") as writer:
- writer.write(data)
-
-
-def write_info(info: dict, local_dir: Path):
+def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
+ """Load dataset info metadata from its standard file path.
+
+ Also converts shape lists to tuples for consistency.
+
+ Args:
+ local_dir (Path): The root directory of the dataset.
+
+ Returns:
+ dict: The dataset information dictionary.
+ """
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
-def write_stats(stats: dict, local_dir: Path):
+def write_stats(stats: dict, local_dir: Path) -> None:
+ """Serialize and write dataset statistics to their standard file path.
+
+ Args:
+ stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
+ local_dir (Path): The root directory of the dataset.
+ """
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
-def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
+ """Recursively cast numerical values in a stats dictionary to numpy arrays.
+
+ Args:
+ stats (dict): The statistics dictionary.
+
+ Returns:
+ dict: The statistics dictionary with values cast to numpy arrays.
+ """
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
-def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
+ """Load dataset statistics and cast numerical values to numpy arrays.
+
+ Returns None if the stats file doesn't exist.
+
+ Args:
+ local_dir (Path): The root directory of the dataset.
+
+ Returns:
+ A dictionary of statistics or None if the file is not found.
+ """
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
-def write_task(task_index: int, task: dict, local_dir: Path):
- task_dict = {
- "task_index": task_index,
- "task": task,
- }
- append_jsonlines(task_dict, local_dir / TASKS_PATH)
+def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
+ path = local_dir / DEFAULT_TASKS_PATH
+ path.parent.mkdir(parents=True, exist_ok=True)
+ tasks.to_parquet(path)
-def load_tasks(local_dir: Path) -> tuple[dict, dict]:
- tasks = load_jsonlines(local_dir / TASKS_PATH)
- tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
- task_to_task_index = {task: task_index for task_index, task in tasks.items()}
- return tasks, task_to_task_index
+def load_tasks(local_dir: Path) -> pandas.DataFrame:
+ tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
+ return tasks
-def write_episode(episode: dict, local_dir: Path):
- append_jsonlines(episode, local_dir / EPISODES_PATH)
+def write_episodes(episodes: Dataset, local_dir: Path) -> None:
+ """Write episode metadata to a parquet file in the LeRobot v3.0 format.
+ This function writes episode-level metadata to a single parquet file.
+ Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures.
+
+ Args:
+ episodes: HuggingFace Dataset containing episode metadata
+ local_dir: Root directory where the dataset will be stored
+ """
+ episode_size_mb = get_hf_dataset_size_in_mb(episodes)
+ if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
+ raise NotImplementedError(
+ f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
+ f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
+ "This function only supports single-file episode metadata. "
+ )
+
+ fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
+ fpath.parent.mkdir(parents=True, exist_ok=True)
+ episodes.to_parquet(fpath)
-def load_episodes(local_dir: Path) -> dict:
- episodes = load_jsonlines(local_dir / EPISODES_PATH)
- return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
-
-
-def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
- # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
- # is a dictionary of stats and not an integer.
- episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
- append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
-
-
-def load_episodes_stats(local_dir: Path) -> dict:
- episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
- return {
- item["episode_index"]: cast_stats_to_numpy(item["stats"])
- for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
- }
-
-
-def backward_compatible_episodes_stats(
- stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
-) -> dict[str, dict[str, np.ndarray]]:
- return dict.fromkeys(episodes, stats)
+def load_episodes(local_dir: Path) -> datasets.Dataset:
+ episodes = load_nested_dataset(local_dir / EPISODES_DIR)
+ # Select episode features/columns containing references to episode data and videos
+ # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
+ # This is to speedup access to these data, instead of having to load episode stats.
+ episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
+ return episodes
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
+ """Load an image from a file into a numpy array.
+
+ Args:
+ fpath (str | Path): Path to the image file.
+ dtype (np.dtype): The desired data type of the output array. If floating,
+ pixels are scaled to [0, 1].
+ channel_first (bool): If True, converts the image to (C, H, W) format.
+ Otherwise, it remains in (H, W, C) format.
+
+ Returns:
+ np.ndarray: The image as a numpy array.
+ """
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
@@ -255,11 +395,20 @@ def load_image_as_numpy(
return img_array
-def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
- """Get a transform function that convert items from Hugging Face dataset (pyarrow)
- to torch tensors. Importantly, images are converted from PIL, which corresponds to
- a channel last representation (h w c) of uint8 type, to a torch image representation
- with channel first (c h w) of float32 type in range [0,1].
+def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
+ """Convert a batch from a Hugging Face dataset to torch tensors.
+
+ This transform function converts items from Hugging Face dataset format (pyarrow)
+ to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
+ to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
+ types are converted to torch.tensor.
+
+ Args:
+ items_dict (dict): A dictionary representing a batch of data from a
+ Hugging Face dataset.
+
+ Returns:
+ dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
@@ -274,6 +423,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def is_valid_version(version: str) -> bool:
+ """Check if a string is a valid PEP 440 version.
+
+ Args:
+ version (str): The version string to check.
+
+ Returns:
+ bool: True if the version string is valid, False otherwise.
+ """
try:
packaging.version.parse(version)
return True
@@ -287,6 +444,18 @@ def check_version_compatibility(
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
+ """Check for version compatibility between a dataset and the current codebase.
+
+ Args:
+ repo_id (str): The repository ID for logging purposes.
+ version_to_check (str | packaging.version.Version): The version of the dataset.
+ current_version (str | packaging.version.Version): The current version of the codebase.
+ enforce_breaking_major (bool): If True, raise an error on major version mismatch.
+
+ Raises:
+ BackwardCompatibilityError: If the dataset version is from a newer, incompatible
+ major version of the codebase.
+ """
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
@@ -300,11 +469,18 @@ def check_version_compatibility(
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
- logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
+ logging.warning(FUTURE_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
- """Returns available valid versions (branches and tags) on given repo."""
+ """Return available valid versions (branches and tags) on a given Hub repo.
+
+ Args:
+ repo_id (str): The repository ID on the Hugging Face Hub.
+
+ Returns:
+ list[packaging.version.Version]: A list of valid versions found.
+ """
api = HfApi()
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
@@ -317,9 +493,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
- """
- Returns the version if available on repo or the latest compatible one.
- Otherwise, will throw a `CompatibilityError`.
+ """Return the specified version if available on repo, or the latest compatible one.
+
+ If the exact version is not found, it looks for the latest version with the
+ same major version number that is less than or equal to the target minor version.
+
+ Args:
+ repo_id (str): The repository ID on the Hugging Face Hub.
+ version (str | packaging.version.Version): The target version.
+
+ Returns:
+ str: The safe version string (e.g., "v1.2.3") to use as a revision.
+
+ Raises:
+ RevisionNotFoundError: If the repo has no version tags.
+ BackwardCompatibilityError: If only older major versions are available.
+ ForwardCompatibilityError: If only newer major versions are available.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
@@ -361,6 +550,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features:
+ """Convert a LeRobot features dictionary to a `datasets.Features` object.
+
+ Args:
+ features (dict): A LeRobot-style feature dictionary.
+
+ Returns:
+ datasets.Features: The corresponding Hugging Face `datasets.Features` object.
+
+ Raises:
+ ValueError: If a feature has an unsupported shape.
+ """
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
@@ -388,6 +588,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
def _validate_feature_names(features: dict[str, dict]) -> None:
+ """Validate that feature names do not contain invalid characters.
+
+ Args:
+ features (dict): The LeRobot features dictionary.
+
+ Raises:
+ ValueError: If any feature name contains '/'.
+ """
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
@@ -396,18 +604,38 @@ def _validate_feature_names(features: dict[str, dict]) -> None:
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
+ """Convert hardware-specific features to a LeRobot dataset feature dictionary.
+
+ This function takes a dictionary describing hardware outputs (like joint states
+ or camera image shapes) and formats it into the standard LeRobot feature
+ specification.
+
+ Args:
+ hw_features (dict): Dictionary mapping feature names to their type (float for
+ joints) or shape (tuple for images).
+ prefix (str): The prefix to add to the feature keys (e.g., "observation"
+ or "action").
+ use_video (bool): If True, image features are marked as "video", otherwise "image".
+
+ Returns:
+ dict: A LeRobot features dictionary.
+ """
features = {}
- joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
+ joint_fts = {
+ key: ftype
+ for key, ftype in hw_features.items()
+ if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
+ }
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
- if joint_fts and prefix == "action":
+ if joint_fts and prefix == ACTION:
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
- if joint_fts and prefix == "observation":
+ if joint_fts and prefix == OBS_STR:
features[f"{prefix}.state"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
@@ -428,6 +656,20 @@ def hw_to_dataset_features(
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
+ """Construct a single data frame from raw values based on dataset features.
+
+ A "frame" is a dictionary containing all the data for a single timestep,
+ formatted as numpy arrays according to the feature specification.
+
+ Args:
+ ds_features (dict): The LeRobot dataset features dictionary.
+ values (dict): A dictionary of raw values from the hardware/environment.
+ prefix (str): The prefix to filter features by (e.g., "observation"
+ or "action").
+
+ Returns:
+ dict: A dictionary representing a single frame of data.
+ """
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
@@ -440,17 +682,22 @@ def build_dataset_frame(
return frame
-def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
- camera_ft = {}
- if robot.cameras:
- camera_ft = {
- key: {"dtype": "video" if use_videos else "image", **ft}
- for key, ft in robot.camera_features.items()
- }
- return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
-
-
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
+ """Convert dataset features to policy features.
+
+ This function transforms the dataset's feature specification into a format
+ that a policy can use, classifying features by type (e.g., visual, state,
+ action) and ensuring correct shapes (e.g., channel-first for images).
+
+ Args:
+ features (dict): The LeRobot dataset features dictionary.
+
+ Returns:
+ dict: A dictionary mapping feature keys to `PolicyFeature` objects.
+
+ Raises:
+ ValueError: If an image feature does not have a 3D shape.
+ """
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
@@ -464,11 +711,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
- elif key == "observation.environment_state":
+ elif key == OBS_ENV_STATE:
type = FeatureType.ENV
- elif key.startswith("observation"):
+ elif key.startswith(OBS_STR):
type = FeatureType.STATE
- elif key.startswith("action"):
+ elif key.startswith(ACTION):
type = FeatureType.ACTION
else:
continue
@@ -481,123 +728,117 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
return policy_features
+def combine_feature_dicts(*dicts: dict) -> dict:
+ """Merge LeRobot grouped feature dicts.
+
+ - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
+ - For others (e.g. `observation.images.*`), the last one wins (if they are identical).
+
+ Args:
+ *dicts: A variable number of LeRobot feature dictionaries to merge.
+
+ Returns:
+ dict: A single merged feature dictionary.
+
+ Raises:
+ ValueError: If there's a dtype mismatch for a feature being merged.
+ """
+ out: dict = {}
+ for d in dicts:
+ for key, value in d.items():
+ if not isinstance(value, dict):
+ out[key] = value
+ continue
+
+ dtype = value.get("dtype")
+ shape = value.get("shape")
+ is_vector = (
+ dtype not in ("image", "video", "string")
+ and isinstance(shape, tuple)
+ and len(shape) == 1
+ and "names" in value
+ )
+
+ if is_vector:
+ # Initialize or retrieve the accumulating dict for this feature key
+ target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
+ # Ensure consistent data types across merged entries
+ if "dtype" in target and dtype != target["dtype"]:
+ raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
+
+ # Merge feature names: append only new ones to preserve order without duplicates
+ seen = set(target["names"])
+ for n in value["names"]:
+ if n not in seen:
+ target["names"].append(n)
+ seen.add(n)
+ # Recompute the shape to reflect the updated number of features
+ target["shape"] = (len(target["names"]),)
+ else:
+ # For images/videos and non-1D entries: override with the latest definition
+ out[key] = value
+ return out
+
+
def create_empty_dataset_info(
codebase_version: str,
fps: int,
features: dict,
use_videos: bool,
robot_type: str | None = None,
+ chunks_size: int | None = None,
+ data_files_size_in_mb: int | None = None,
+ video_files_size_in_mb: int | None = None,
) -> dict:
+ """Create a template dictionary for a new dataset's `info.json`.
+
+ Args:
+ codebase_version (str): The version of the LeRobot codebase.
+ fps (int): The frames per second of the data.
+ features (dict): The LeRobot features dictionary for the dataset.
+ use_videos (bool): Whether the dataset will store videos.
+ robot_type (str | None): The type of robot used, if any.
+
+ Returns:
+ dict: A dictionary with the initial dataset metadata.
+ """
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
- "total_videos": 0,
- "total_chunks": 0,
- "chunks_size": DEFAULT_CHUNK_SIZE,
+ "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
+ "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
+ "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
- "data_path": DEFAULT_PARQUET_PATH,
+ "data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
-def get_episode_data_index(
- episode_dicts: dict[dict], episodes: list[int] | None = None
-) -> dict[str, torch.Tensor]:
- episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
- if episodes is not None:
- episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
-
- cumulative_lengths = list(accumulate(episode_lengths.values()))
- return {
- "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
- "to": torch.LongTensor(cumulative_lengths),
- }
-
-
-def check_timestamps_sync(
- timestamps: np.ndarray,
- episode_indices: np.ndarray,
- episode_data_index: dict[str, np.ndarray],
- fps: int,
- tolerance_s: float,
- raise_value_error: bool = True,
-) -> bool:
- """
- This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
- to account for possible numerical error.
-
- Args:
- timestamps (np.ndarray): Array of timestamps in seconds.
- episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
- episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
- which identifies indices for the end of each episode.
- fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
- tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
- raise_value_error (bool): Whether to raise a ValueError if the check fails.
-
- Returns:
- bool: True if all checked timestamp differences lie within tolerance, False otherwise.
-
- Raises:
- ValueError: If the check fails and `raise_value_error` is True.
- """
- if timestamps.shape != episode_indices.shape:
- raise ValueError(
- "timestamps and episode_indices should have the same shape. "
- f"Found {timestamps.shape=} and {episode_indices.shape=}."
- )
-
- # Consecutive differences
- diffs = np.diff(timestamps)
- within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
-
- # Mask to ignore differences at the boundaries between episodes
- mask = np.ones(len(diffs), dtype=bool)
- ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
- mask[ignored_diffs] = False
- filtered_within_tolerance = within_tolerance[mask]
-
- # Check if all remaining diffs are within tolerance
- if not np.all(filtered_within_tolerance):
- # Track original indices before masking
- original_indices = np.arange(len(diffs))
- filtered_indices = original_indices[mask]
- outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
- outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
-
- outside_tolerances = []
- for idx in outside_tolerance_indices:
- entry = {
- "timestamps": [timestamps[idx], timestamps[idx + 1]],
- "diff": diffs[idx],
- "episode_index": episode_indices[idx].item()
- if hasattr(episode_indices[idx], "item")
- else episode_indices[idx],
- }
- outside_tolerances.append(entry)
-
- if raise_value_error:
- raise ValueError(
- f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
- This might be due to synchronization issues during data collection.
- \n{pformat(outside_tolerances)}"""
- )
- return False
-
- return True
-
-
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
- """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
- This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
- actual timestamps from the dataset.
+ """Check if delta timestamps are multiples of 1/fps +/- tolerance.
+
+ This ensures that adding these delta timestamps to any existing timestamp in
+ the dataset will result in a value that aligns with the dataset's frame rate.
+
+ Args:
+ delta_timestamps (dict): A dictionary where values are lists of time
+ deltas in seconds.
+ fps (int): The frames per second of the dataset.
+ tolerance_s (float): The allowed tolerance in seconds.
+ raise_value_error (bool): If True, raises an error on failure.
+
+ Returns:
+ bool: True if all deltas are valid, False otherwise.
+
+ Raises:
+ ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
@@ -623,6 +864,15 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
+ """Convert delta timestamps in seconds to delta indices in frames.
+
+ Args:
+ delta_timestamps (dict): A dictionary of time deltas in seconds.
+ fps (int): The frames per second of the dataset.
+
+ Returns:
+ dict: A dictionary of frame delta indices.
+ """
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@@ -630,10 +880,18 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
return delta_indices
-def cycle(iterable):
- """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
+def cycle(iterable: Any) -> Iterator[Any]:
+ """Create a dataloader-safe cyclical iterator.
- See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
+ This is an equivalent of `itertools.cycle` but is safe for use with
+ PyTorch DataLoaders with multiple workers.
+ See https://github.com/pytorch/pytorch/issues/23900 for details.
+
+ Args:
+ iterable: The iterable to cycle over.
+
+ Yields:
+ Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
@@ -643,9 +901,15 @@ def cycle(iterable):
iterator = iter(iterable)
-def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
- """Create a branch on a existing Hugging Face repo. Delete the branch if it already
- exists before creating it.
+def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None:
+ """Create a branch on an existing Hugging Face repo.
+
+ Deletes the branch if it already exists before creating it.
+
+ Args:
+ repo_id (str): The ID of the repository.
+ branch (str): The name of the branch to create.
+ repo_type (str | None): The type of the repository (e.g., "dataset").
"""
api = HfApi()
@@ -663,9 +927,20 @@ def create_lerobot_dataset_card(
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
- """
- Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
- Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
+ """Create a `DatasetCard` for a LeRobot dataset.
+
+ Keyword arguments are used to replace values in the card template.
+ Note: If specified, `license` must be a valid license identifier from
+ https://huggingface.co/docs/hub/repositories-licenses.
+
+ Args:
+ tags (list | None): A list of tags to add to the dataset card.
+ dataset_info (dict | None): The dataset's info dictionary, which will
+ be displayed on the card.
+ **kwargs: Additional keyword arguments to populate the card template.
+
+ Returns:
+ DatasetCard: The generated dataset card object.
"""
card_tags = ["LeRobot"]
@@ -696,76 +971,37 @@ def create_lerobot_dataset_card(
)
-class IterableNamespace(SimpleNamespace):
- """
- A namespace object that supports both dictionary-like iteration and dot notation access.
- Automatically converts nested dictionaries into IterableNamespaces.
-
- This class extends SimpleNamespace to provide:
- - Dictionary-style iteration over keys
- - Access to items via both dot notation (obj.key) and brackets (obj["key"])
- - Dictionary-like methods: items(), keys(), values()
- - Recursive conversion of nested dictionaries
-
- Args:
- dictionary: Optional dictionary to initialize the namespace
- **kwargs: Additional keyword arguments passed to SimpleNamespace
-
- Examples:
- >>> data = {"name": "Alice", "details": {"age": 25}}
- >>> ns = IterableNamespace(data)
- >>> ns.name
- 'Alice'
- >>> ns.details.age
- 25
- >>> list(ns.keys())
- ['name', 'details']
- >>> for key, value in ns.items():
- ... print(f"{key}: {value}")
- name: Alice
- details: IterableNamespace(age=25)
- """
-
- def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
- super().__init__(**kwargs)
- if dictionary is not None:
- for key, value in dictionary.items():
- if isinstance(value, dict):
- setattr(self, key, IterableNamespace(value))
- else:
- setattr(self, key, value)
-
- def __iter__(self) -> Iterator[str]:
- return iter(vars(self))
-
- def __getitem__(self, key: str) -> Any:
- return vars(self)[key]
-
- def items(self):
- return vars(self).items()
-
- def values(self):
- return vars(self).values()
-
- def keys(self):
- return vars(self).keys()
-
-
-def validate_frame(frame: dict, features: dict):
+def validate_frame(frame: dict, features: dict) -> None:
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
- error_message = validate_features_presence(actual_features, expected_features)
+ # task is a special required field that's not part of regular features
+ if "task" not in actual_features:
+ raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
- common_features = actual_features & expected_features
- for name in common_features - {"task"}:
+ # Remove task from actual_features for regular feature validation
+ actual_features_for_validation = actual_features - {"task"}
+
+ error_message = validate_features_presence(actual_features_for_validation, expected_features)
+
+ common_features = actual_features_for_validation & expected_features
+ for name in common_features:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
-def validate_features_presence(actual_features: set[str], expected_features: set[str]):
+def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
+ """Check for missing or extra features in a frame.
+
+ Args:
+ actual_features (set[str]): The set of feature names present in the frame.
+ expected_features (set[str]): The set of feature names expected in the frame.
+
+ Returns:
+ str: An error message string if there's a mismatch, otherwise an empty string.
+ """
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
@@ -780,7 +1016,22 @@ def validate_features_presence(actual_features: set[str], expected_features: set
return error_message
-def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+def validate_feature_dtype_and_shape(
+ name: str, feature: dict, value: np.ndarray | PILImage.Image | str
+) -> str:
+ """Validate the dtype and shape of a single feature's value.
+
+ Args:
+ name (str): The name of the feature.
+ feature (dict): The feature specification from the LeRobot features dictionary.
+ value: The value of the feature to validate.
+
+ Returns:
+ str: An error message if validation fails, otherwise an empty string.
+
+ Raises:
+ NotImplementedError: If the feature dtype is not supported for validation.
+ """
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -795,7 +1046,18 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
-):
+) -> str:
+ """Validate a feature that is expected to be a numpy array.
+
+ Args:
+ name (str): The name of the feature.
+ expected_dtype (str): The expected numpy dtype as a string.
+ expected_shape (list[int]): The expected shape.
+ value (np.ndarray): The numpy array to validate.
+
+ Returns:
+ str: An error message if validation fails, otherwise an empty string.
+ """
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -812,7 +1074,21 @@ def validate_feature_numpy_array(
return error_message
-def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+def validate_feature_image_or_video(
+ name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
+) -> str:
+ """Validate a feature that is expected to be an image or video frame.
+
+ Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
+
+ Args:
+ name (str): The name of the feature.
+ expected_shape (list[str]): The expected shape (C, H, W).
+ value: The image data to validate.
+
+ Returns:
+ str: An error message if validation fails, otherwise an empty string.
+ """
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -828,13 +1104,36 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
return error_message
-def validate_feature_string(name: str, value: str):
+def validate_feature_string(name: str, value: str) -> str:
+ """Validate a feature that is expected to be a string.
+
+ Args:
+ name (str): The name of the feature.
+ value (str): The value to validate.
+
+ Returns:
+ str: An error message if validation fails, otherwise an empty string.
+ """
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
-def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
+def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
+ """Validate the episode buffer before it's written to disk.
+
+ Ensures the buffer has the required keys, contains at least one frame, and
+ has features consistent with the dataset's specification.
+
+ Args:
+ episode_buffer (dict): The buffer containing data for a single episode.
+ total_episodes (int): The current total number of episodes in the dataset.
+ features (dict): The LeRobot features dictionary for the dataset.
+
+ Raises:
+ ValueError: If the buffer is invalid.
+ NotImplementedError: If the episode index is manually set and doesn't match.
+ """
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
@@ -858,3 +1157,207 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
+
+
+def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
+ """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
+ This way, it can be loaded by HF dataset and correctly formatted images are returned.
+ """
+ # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
+ datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+
+
+def item_to_torch(item: dict) -> dict:
+ """Convert all items in a dictionary to PyTorch tensors where appropriate.
+
+ This function is used to convert an item from a streaming dataset to PyTorch tensors.
+
+ Args:
+ item (dict): Dictionary of items from a dataset.
+
+ Returns:
+ dict: Dictionary with all tensor-like items converted to torch.Tensor.
+ """
+ for key, val in item.items():
+ if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
+ # Convert numpy arrays and lists to torch tensors
+ item[key] = torch.tensor(val)
+ return item
+
+
+def is_float_in_list(target, float_list, threshold=1e-6):
+ return any(abs(target - x) <= threshold for x in float_list)
+
+
+def find_float_index(target, float_list, threshold=1e-6):
+ for i, x in enumerate(float_list):
+ if abs(target - x) <= threshold:
+ return i
+ return -1
+
+
+class LookBackError(Exception):
+ """
+ Exception raised when trying to look back in the history of a Backtrackable object.
+ """
+
+ pass
+
+
+class LookAheadError(Exception):
+ """
+ Exception raised when trying to look ahead in the future of a Backtrackable object.
+ """
+
+ pass
+
+
+class Backtrackable(Generic[T]):
+ """
+ Wrap any iterator/iterable so you can step back up to `history` items
+ and look ahead up to `lookahead` items.
+
+ This is useful for streaming datasets where you need to access previous and future items
+ but can't load the entire dataset into memory.
+
+ Example:
+ -------
+ ```python
+ ds = load_dataset("c4", "en", streaming=True, split="train")
+ rev = Backtrackable(ds, history=3, lookahead=2)
+
+ x0 = next(rev) # forward
+ x1 = next(rev)
+ x2 = next(rev)
+
+ # Look ahead
+ x3_peek = rev.peek_ahead(1) # next item without moving cursor
+ x4_peek = rev.peek_ahead(2) # two items ahead
+
+ # Look back
+ x1_again = rev.peek_back(1) # previous item without moving cursor
+ x0_again = rev.peek_back(2) # two items back
+
+ # Move backward
+ x1_back = rev.prev() # back one step
+ next(rev) # returns x2, continues forward from where we were
+ ```
+ """
+
+ __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
+
+ def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
+ if history < 1:
+ raise ValueError("history must be >= 1")
+ if lookahead <= 0:
+ raise ValueError("lookahead must be > 0")
+
+ self._source: Iterator[T] = iter(iterable)
+ self._back_buf: deque[T] = deque(maxlen=history)
+ self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
+ self._cursor: int = 0
+ self._history = history
+ self._lookahead = lookahead
+
+ def __iter__(self) -> "Backtrackable[T]":
+ return self
+
+ def __next__(self) -> T:
+ # If we've stepped back, consume from back buffer first
+ if self._cursor < 0: # -1 means "last item", etc.
+ self._cursor += 1
+ return self._back_buf[self._cursor]
+
+ # If we have items in the ahead buffer, use them first
+ item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
+
+ # Add current item to back buffer and reset cursor
+ self._back_buf.append(item)
+ self._cursor = 0
+ return item
+
+ def prev(self) -> T:
+ """
+ Step one item back in history and return it.
+ Raises IndexError if already at the oldest buffered item.
+ """
+ if len(self._back_buf) + self._cursor <= 1:
+ raise LookBackError("At start of history")
+
+ self._cursor -= 1
+ return self._back_buf[self._cursor]
+
+ def peek_back(self, n: int = 1) -> T:
+ """
+ Look `n` items back (n=1 == previous item) without moving the cursor.
+ """
+ if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
+ raise LookBackError("peek_back distance out of range")
+
+ return self._back_buf[self._cursor - (n + 1)]
+
+ def peek_ahead(self, n: int = 1) -> T:
+ """
+ Look `n` items ahead (n=1 == next item) without moving the cursor.
+ Fills the ahead buffer if necessary.
+ """
+ if n < 1:
+ raise LookAheadError("peek_ahead distance must be 1 or more")
+ elif n > self._lookahead:
+ raise LookAheadError("peek_ahead distance exceeds lookahead limit")
+
+ # Fill ahead buffer if we don't have enough items
+ while len(self._ahead_buf) < n:
+ try:
+ item = next(self._source)
+ self._ahead_buf.append(item)
+
+ except StopIteration as err:
+ raise LookAheadError("peek_ahead: not enough items in source") from err
+
+ return self._ahead_buf[n - 1]
+
+ def history(self) -> list[T]:
+ """
+ Return a copy of the buffered history (most recent last).
+ The list length ≤ `history` argument passed at construction.
+ """
+ if self._cursor == 0:
+ return list(self._back_buf)
+
+ # When cursor<0, slice so the order remains chronological
+ return list(self._back_buf)[: self._cursor or None]
+
+ def can_peek_back(self, steps: int = 1) -> bool:
+ """
+ Check if we can go back `steps` items without raising an IndexError.
+ """
+ return steps <= len(self._back_buf) + self._cursor
+
+ def can_peek_ahead(self, steps: int = 1) -> bool:
+ """
+ Check if we can peek ahead `steps` items.
+ This may involve trying to fill the ahead buffer.
+ """
+ if self._lookahead > 0 and steps > self._lookahead:
+ return False
+
+ # Try to fill ahead buffer to check if we can peek that far
+ try:
+ while len(self._ahead_buf) < steps:
+ if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
+ return False
+ item = next(self._source)
+ self._ahead_buf.append(item)
+ return True
+ except StopIteration:
+ return False
+
+
+def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
+ """
+ Safe shards the dataset.
+ """
+ shard_idx = min(dataset.num_shards, index + 1) - 1
+
+ return dataset.shard(num_shards, index=shard_idx)
diff --git a/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py b/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py
deleted file mode 100644
index fa99c725e..000000000
--- a/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py
+++ /dev/null
@@ -1,884 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
-
-Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
-lerobot/configs/robot/aloha.yaml before running this script.
-"""
-
-import traceback
-from pathlib import Path
-from textwrap import dedent
-
-from lerobot import available_datasets
-from lerobot.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
-from lerobot.robots.aloha.configuration_aloha import AlohaRobotConfig
-
-LOCAL_DIR = Path("data/")
-
-# spellchecker:off
-ALOHA_MOBILE_INFO = {
- "robot_config": AlohaRobotConfig(),
- "license": "mit",
- "url": "https://mobile-aloha.github.io/",
- "paper": "https://huggingface.co/papers/2401.02117",
- "citation_bibtex": dedent(r"""
- @inproceedings{fu2024mobile,
- author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
- title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
- booktitle = {arXiv},
- year = {2024},
- }""").lstrip(),
-}
-ALOHA_STATIC_INFO = {
- "robot_config": AlohaRobotConfig(),
- "license": "mit",
- "url": "https://tonyzhaozh.github.io/aloha/",
- "paper": "https://huggingface.co/papers/2304.13705",
- "citation_bibtex": dedent(r"""
- @article{Zhao2023LearningFB,
- title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
- author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
- journal={RSS},
- year={2023},
- volume={abs/2304.13705},
- url={https://huggingface.co/papers/2304.13705}
- }""").lstrip(),
-}
-PUSHT_INFO = {
- "license": "mit",
- "url": "https://diffusion-policy.cs.columbia.edu/",
- "paper": "https://huggingface.co/papers/2303.04137",
- "citation_bibtex": dedent(r"""
- @article{chi2024diffusionpolicy,
- author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
- title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
- journal = {The International Journal of Robotics Research},
- year = {2024},
- }""").lstrip(),
-}
-XARM_INFO = {
- "license": "mit",
- "url": "https://www.nicklashansen.com/td-mpc/",
- "paper": "https://huggingface.co/papers/2203.04955",
- "citation_bibtex": dedent(r"""
- @inproceedings{Hansen2022tdmpc,
- title={Temporal Difference Learning for Model Predictive Control},
- author={Nicklas Hansen and Xiaolong Wang and Hao Su},
- booktitle={ICML},
- year={2022}
- }
- """),
-}
-UNITREEH_INFO = {
- "license": "apache-2.0",
-}
-
-DATASETS = {
- "aloha_mobile_cabinet": {
- "single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_mobile_chair": {
- "single_task": "Push the chairs in front of the desk to place them against it.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_mobile_elevator": {
- "single_task": "Take the elevator to the 1st floor.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_mobile_shrimp": {
- "single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_mobile_wash_pan": {
- "single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_mobile_wipe_wine": {
- "single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
- **ALOHA_MOBILE_INFO,
- },
- "aloha_static_battery": {
- "single_task": "Place the battery into the slot of the remote controller.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
- "aloha_static_coffee": {
- "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_coffee_new": {
- "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_cups_open": {
- "single_task": "Pick up the plastic cup and open its lid.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_fork_pick_up": {
- "single_task": "Pick up the fork and place it on the plate.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_pingpong_test": {
- "single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_pro_pencil": {
- "single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_screw_driver": {
- "single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_tape": {
- "single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_thread_velcro": {
- "single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_towel": {
- "single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_vinh_cup": {
- "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_vinh_cup_left": {
- "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
- "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
- "aloha_sim_insertion_scripted_image": {
- "single_task": "Insert the peg into the socket.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
- "aloha_sim_insertion_human_image": {
- "single_task": "Insert the peg into the socket.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_sim_transfer_cube_scripted": {
- "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_sim_transfer_cube_scripted_image": {
- "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_sim_transfer_cube_human": {
- "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
- **ALOHA_STATIC_INFO,
- },
- "aloha_sim_transfer_cube_human_image": {
- "single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
- **ALOHA_STATIC_INFO,
- },
- "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
- "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
- "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
- "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
- "unitreeh1_two_robot_greeting": {
- "single_task": "Greet the other robot with a high five.",
- **UNITREEH_INFO,
- },
- "unitreeh1_warehouse": {
- "single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
- **UNITREEH_INFO,
- },
- "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
- "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
- "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
- "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
- "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
- "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
- "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
- "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
- "umi_cup_in_the_wild": {
- "single_task": "Put the cup on the plate.",
- "license": "apache-2.0",
- },
- "asu_table_top": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
- "citation_bibtex": dedent(r"""
- @inproceedings{zhou2023modularity,
- title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
- author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
- booktitle={Conference on Robot Learning},
- pages={1684--1695},
- year={2023},
- organization={PMLR}
- }
- @article{zhou2023learning,
- title={Learning modular language-conditioned robot policies through attention},
- author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
- journal={Autonomous Robots},
- pages={1--21},
- year={2023},
- publisher={Springer}
- }""").lstrip(),
- },
- "austin_buds_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://ut-austin-rpl.github.io/BUDS-website/",
- "paper": "https://huggingface.co/papers/2109.13841",
- "citation_bibtex": dedent(r"""
- @article{zhu2022bottom,
- title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
- author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
- journal={IEEE Robotics and Automation Letters},
- volume={7},
- number={2},
- pages={4126--4133},
- year={2022},
- publisher={IEEE}
- }""").lstrip(),
- },
- "austin_sailor_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://ut-austin-rpl.github.io/sailor/",
- "paper": "https://huggingface.co/papers/2210.11435",
- "citation_bibtex": dedent(r"""
- @inproceedings{nasiriany2022sailor,
- title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
- author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
- booktitle={Conference on Robot Learning (CoRL)},
- year={2022}
- }""").lstrip(),
- },
- "austin_sirius_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://ut-austin-rpl.github.io/sirius/",
- "paper": "https://huggingface.co/papers/2211.08416",
- "citation_bibtex": dedent(r"""
- @inproceedings{liu2022robot,
- title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
- author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
- booktitle = {Robotics: Science and Systems (RSS)},
- year = {2023}
- }""").lstrip(),
- },
- "berkeley_autolab_ur5": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://sites.google.com/view/berkeley-ur5/home",
- "citation_bibtex": dedent(r"""
- @misc{BerkeleyUR5Website,
- title = {Berkeley {UR5} Demonstration Dataset},
- author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
- howpublished = {https://sites.google.com/view/berkeley-ur5/home},
- }""").lstrip(),
- },
- "berkeley_cable_routing": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://sites.google.com/view/cablerouting/home",
- "paper": "https://huggingface.co/papers/2307.08927",
- "citation_bibtex": dedent(r"""
- @article{luo2023multistage,
- author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
- title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
- journal = {arXiv pre-print},
- year = {2023},
- url = {https://huggingface.co/papers/2307.08927},
- }""").lstrip(),
- },
- "berkeley_fanuc_manipulation": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
- "citation_bibtex": dedent(r"""
- @article{fanuc_manipulation2023,
- title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
- author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
- year={2023},
- }""").lstrip(),
- },
- "berkeley_gnm_cory_hall": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://huggingface.co/papers/1709.10489",
- "citation_bibtex": dedent(r"""
- @inproceedings{kahn2018self,
- title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
- author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
- booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
- pages={5129--5136},
- year={2018},
- organization={IEEE}
- }""").lstrip(),
- },
- "berkeley_gnm_recon": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/view/recon-robot",
- "paper": "https://huggingface.co/papers/2104.05859",
- "citation_bibtex": dedent(r"""
- @inproceedings{shah2021rapid,
- title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
- author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
- booktitle={5th Annual Conference on Robot Learning },
- year={2021},
- url={https://openreview.net/forum?id=d_SWJhyKfVw}
- }""").lstrip(),
- },
- "berkeley_gnm_sac_son": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/view/SACSoN-review",
- "paper": "https://huggingface.co/papers/2306.01874",
- "citation_bibtex": dedent(r"""
- @article{hirose2023sacson,
- title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
- author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
- journal={arXiv preprint arXiv:2306.01874},
- year={2023}
- }""").lstrip(),
- },
- "berkeley_mvp": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://huggingface.co/papers/2203.06173",
- "citation_bibtex": dedent(r"""
- @InProceedings{Radosavovic2022,
- title = {Real-World Robot Learning with Masked Visual Pre-training},
- author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
- booktitle = {CoRL},
- year = {2022}
- }""").lstrip(),
- },
- "berkeley_rpt": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://huggingface.co/papers/2306.10007",
- "citation_bibtex": dedent(r"""
- @article{Radosavovic2023,
- title={Robot Learning with Sensorimotor Pre-training},
- author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
- year={2023},
- journal={arXiv:2306.10007}
- }""").lstrip(),
- },
- "cmu_franka_exploration_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://human-world-model.github.io/",
- "paper": "https://huggingface.co/papers/2308.10901",
- "citation_bibtex": dedent(r"""
- @inproceedings{mendonca2023structured,
- title={Structured World Models from Human Videos},
- author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
- journal={RSS},
- year={2023}
- }""").lstrip(),
- },
- "cmu_play_fusion": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://play-fusion.github.io/",
- "paper": "https://huggingface.co/papers/2312.04549",
- "citation_bibtex": dedent(r"""
- @inproceedings{chen2023playfusion,
- title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
- author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
- booktitle={CoRL},
- year={2023}
- }""").lstrip(),
- },
- "cmu_stretch": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://robo-affordances.github.io/",
- "paper": "https://huggingface.co/papers/2304.08488",
- "citation_bibtex": dedent(r"""
- @inproceedings{bahl2023affordances,
- title={Affordances from Human Videos as a Versatile Representation for Robotics},
- author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
- booktitle={CVPR},
- year={2023}
- }
- @article{mendonca2023structured,
- title={Structured World Models from Human Videos},
- author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
- journal={CoRL},
- year={2023}
- }""").lstrip(),
- },
- "columbia_cairlab_pusht_real": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://diffusion-policy.cs.columbia.edu/",
- "paper": "https://huggingface.co/papers/2303.04137",
- "citation_bibtex": dedent(r"""
- @inproceedings{chi2023diffusionpolicy,
- title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
- author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
- booktitle={Proceedings of Robotics: Science and Systems (RSS)},
- year={2023}
- }""").lstrip(),
- },
- "conq_hose_manipulation": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
- "citation_bibtex": dedent(r"""
- @misc{ConqHoseManipData,
- author={Peter Mitrano and Dmitry Berenson},
- title={Conq Hose Manipulation Dataset, v1.15.0},
- year={2024},
- howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
- }""").lstrip(),
- },
- "dlr_edan_shared_control": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://ieeexplore.ieee.org/document/9341156",
- "citation_bibtex": dedent(r"""
- @inproceedings{vogel_edan_2020,
- title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
- language = {en},
- booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
- author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
- year = {2020}
- }
- @inproceedings{quere_shared_2020,
- address = {Paris, France},
- title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
- language = {en},
- booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
- author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
- year = {2020},
- pages = {7},
- }""").lstrip(),
- },
- "dlr_sara_grid_clamp": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://www.researchsquare.com/article/rs-3289569/v1",
- "citation_bibtex": dedent(r"""
- @article{padalkar2023guided,
- title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
- author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
- journal={Research square preprint rs-3289569/v1},
- year={2023}
- }""").lstrip(),
- },
- "dlr_sara_pour": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
- "citation_bibtex": dedent(r"""
- @inproceedings{padalkar2023guiding,
- title={Guiding Reinforcement Learning with Shared Control Templates},
- author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
- booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
- year={2023},
- organization={IEEE}
- }""").lstrip(),
- },
- "droid_100": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://droid-dataset.github.io/",
- "paper": "https://huggingface.co/papers/2403.12945",
- "citation_bibtex": dedent(r"""
- @article{khazatsky2024droid,
- title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
- author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
- year = {2024},
- }""").lstrip(),
- },
- "fmb": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://functional-manipulation-benchmark.github.io/",
- "paper": "https://huggingface.co/papers/2401.08553",
- "citation_bibtex": dedent(r"""
- @article{luo2024fmb,
- title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
- author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
- journal={arXiv preprint arXiv:2401.08553},
- year={2024}
- }""").lstrip(),
- },
- "iamlab_cmu_pickup_insert": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://openreview.net/forum?id=WuBv9-IGDUA",
- "paper": "https://huggingface.co/papers/2401.14502",
- "citation_bibtex": dedent(r"""
- @inproceedings{saxena2023multiresolution,
- title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
- author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
- booktitle={7th Annual Conference on Robot Learning},
- year={2023},
- url={https://openreview.net/forum?id=WuBv9-IGDUA}
- }""").lstrip(),
- },
- "imperialcollege_sawyer_wrist_cam": {
- "tasks_col": "language_instruction",
- "license": "mit",
- },
- "jaco_play": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://github.com/clvrai/clvr_jaco_play_dataset",
- "citation_bibtex": dedent(r"""
- @software{dass2023jacoplay,
- author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
- and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
- title = {CLVR Jaco Play Dataset},
- url = {https://github.com/clvrai/clvr_jaco_play_dataset},
- version = {1.0.0},
- year = {2023}
- }""").lstrip(),
- },
- "kaist_nonprehensile": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
- "citation_bibtex": dedent(r"""
- @article{kimpre,
- title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
- author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
- booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
- year={2023},
- organization={IEEE}
- }""").lstrip(),
- },
- "nyu_door_opening_surprising_effectiveness": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://jyopari.github.io/VINN/",
- "paper": "https://huggingface.co/papers/2112.01511",
- "citation_bibtex": dedent(r"""
- @misc{pari2021surprising,
- title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
- author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
- year={2021},
- eprint={2112.01511},
- archivePrefix={arXiv},
- primaryClass={cs.RO}
- }""").lstrip(),
- },
- "nyu_franka_play_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://play-to-policy.github.io/",
- "paper": "https://huggingface.co/papers/2210.10047",
- "citation_bibtex": dedent(r"""
- @article{cui2022play,
- title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
- author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
- journal = {arXiv preprint arXiv:2210.10047},
- year = {2022}
- }""").lstrip(),
- },
- "nyu_rot_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://rot-robot.github.io/",
- "paper": "https://huggingface.co/papers/2206.15469",
- "citation_bibtex": dedent(r"""
- @inproceedings{haldar2023watch,
- title={Watch and match: Supercharging imitation with regularized optimal transport},
- author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
- booktitle={Conference on Robot Learning},
- pages={32--43},
- year={2023},
- organization={PMLR}
- }""").lstrip(),
- },
- "roboturk": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://roboturk.stanford.edu/dataset_real.html",
- "paper": "PAPER",
- "citation_bibtex": dedent(r"""
- @inproceedings{mandlekar2019scaling,
- title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
- author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
- booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
- pages={1048--1055},
- year={2019},
- organization={IEEE}
- }""").lstrip(),
- },
- "stanford_hydra_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/view/hydra-il-2023",
- "paper": "https://huggingface.co/papers/2306.17237",
- "citation_bibtex": dedent(r"""
- @article{belkhale2023hydra,
- title={HYDRA: Hybrid Robot Actions for Imitation Learning},
- author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
- journal={arxiv},
- year={2023}
- }""").lstrip(),
- },
- "stanford_kuka_multimodal_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://sites.google.com/view/visionandtouch",
- "paper": "https://huggingface.co/papers/1810.10191",
- "citation_bibtex": dedent(r"""
- @inproceedings{lee2019icra,
- title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
- author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
- booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
- year={2019},
- url={https://huggingface.co/papers/1810.10191}
- }""").lstrip(),
- },
- "stanford_robocook": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://hshi74.github.io/robocook/",
- "paper": "https://huggingface.co/papers/2306.14447",
- "citation_bibtex": dedent(r"""
- @article{shi2023robocook,
- title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
- author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
- journal={arXiv preprint arXiv:2306.14447},
- year={2023}
- }""").lstrip(),
- },
- "taco_play": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
- "paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911",
- "citation_bibtex": dedent(r"""
- @inproceedings{rosete2022tacorl,
- author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
- title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
- journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
- year = {2022}
- }
- @inproceedings{mees23hulc2,
- title={Grounding Language with Visual Affordances over Unstructured Data},
- author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
- booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
- year={2023},
- address = {London, UK}
- }""").lstrip(),
- },
- "tokyo_u_lsmo": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "URL",
- "paper": "https://huggingface.co/papers/2107.05842",
- "citation_bibtex": dedent(r"""
- @Article{Osa22,
- author = {Takayuki Osa},
- journal = {The International Journal of Robotics Research},
- title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
- year = {2022},
- number = {3},
- pages = {291--311},
- volume = {41},
- }""").lstrip(),
- },
- "toto": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://toto-benchmark.org/",
- "paper": "https://huggingface.co/papers/2306.00942",
- "citation_bibtex": dedent(r"""
- @inproceedings{zhou2023train,
- author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
- booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
- title={Train Offline, Test Online: A Real Robot Learning Benchmark},
- year={2023},
- }""").lstrip(),
- },
- "ucsd_kitchen_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "citation_bibtex": dedent(r"""
- @ARTICLE{ucsd_kitchens,
- author = {Ge Yan, Kris Wu, and Xiaolong Wang},
- title = {{ucsd kitchens Dataset}},
- year = {2023},
- month = {August}
- }""").lstrip(),
- },
- "ucsd_pick_and_place_dataset": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://owmcorl.github.io/#",
- "paper": "https://huggingface.co/papers/2310.16029",
- "citation_bibtex": dedent(r"""
- @preprint{Feng2023Finetuning,
- title={Finetuning Offline World Models in the Real World},
- author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
- year={2023}
- }""").lstrip(),
- },
- "uiuc_d3field": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://robopil.github.io/d3fields/",
- "paper": "https://huggingface.co/papers/2309.16118",
- "citation_bibtex": dedent(r"""
- @article{wang2023d3field,
- title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
- author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
- journal={arXiv preprint arXiv:},
- year={2023},
- }""").lstrip(),
- },
- "usc_cloth_sim": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://uscresl.github.io/dmfd/",
- "paper": "https://huggingface.co/papers/2207.10148",
- "citation_bibtex": dedent(r"""
- @article{salhotra2022dmfd,
- author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
- journal={IEEE Robotics and Automation Letters},
- title={Learning Deformable Object Manipulation From Expert Demonstrations},
- year={2022},
- volume={7},
- number={4},
- pages={8775-8782},
- doi={10.1109/LRA.2022.3187843}
- }""").lstrip(),
- },
- "utaustin_mutex": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://ut-austin-rpl.github.io/MUTEX/",
- "paper": "https://huggingface.co/papers/2309.14320",
- "citation_bibtex": dedent(r"""
- @inproceedings{shah2023mutex,
- title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
- author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
- booktitle={7th Annual Conference on Robot Learning},
- year={2023},
- url={https://openreview.net/forum?id=PwqiqaaEzJ}
- }""").lstrip(),
- },
- "utokyo_pr2_opening_fridge": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "citation_bibtex": dedent(r"""
- @misc{oh2023pr2utokyodatasets,
- author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
- title={X-Embodiment U-Tokyo PR2 Datasets},
- year={2023},
- url={https://github.com/ojh6404/rlds_dataset_builder},
- }""").lstrip(),
- },
- "utokyo_pr2_tabletop_manipulation": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "citation_bibtex": dedent(r"""
- @misc{oh2023pr2utokyodatasets,
- author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
- title={X-Embodiment U-Tokyo PR2 Datasets},
- year={2023},
- url={https://github.com/ojh6404/rlds_dataset_builder},
- }""").lstrip(),
- },
- "utokyo_saytap": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://saytap.github.io/",
- "paper": "https://huggingface.co/papers/2306.07580",
- "citation_bibtex": dedent(r"""
- @article{saytap2023,
- author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
- Tatsuya Harada},
- title = {SayTap: Language to Quadrupedal Locomotion},
- eprint = {arXiv:2306.07580},
- url = {https://saytap.github.io},
- note = {https://saytap.github.io},
- year = {2023}
- }""").lstrip(),
- },
- "utokyo_xarm_bimanual": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "citation_bibtex": dedent(r"""
- @misc{matsushima2023weblab,
- title={Weblab xArm Dataset},
- author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
- year={2023},
- }""").lstrip(),
- },
- "utokyo_xarm_pick_and_place": {
- "tasks_col": "language_instruction",
- "license": "cc-by-4.0",
- "citation_bibtex": dedent(r"""
- @misc{matsushima2023weblab,
- title={Weblab xArm Dataset},
- author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
- year={2023},
- }""").lstrip(),
- },
- "viola": {
- "tasks_col": "language_instruction",
- "license": "mit",
- "url": "https://ut-austin-rpl.github.io/VIOLA/",
- "paper": "https://huggingface.co/papers/2210.11339",
- "citation_bibtex": dedent(r"""
- @article{zhu2022viola,
- title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
- author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
- journal={6th Annual Conference on Robot Learning (CoRL)},
- year={2022}
- }""").lstrip(),
- },
-}
-# spellchecker:on
-
-
-def batch_convert():
- status = {}
- logfile = LOCAL_DIR / "conversion_log.txt"
- assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
- for num, (name, kwargs) in enumerate(DATASETS.items()):
- repo_id = f"lerobot/{name}"
- print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
- print("---------------------------------------------------------")
- try:
- convert_dataset(repo_id, LOCAL_DIR, **kwargs)
- status = f"{repo_id}: success."
- with open(logfile, "a") as file:
- file.write(status + "\n")
- except Exception:
- status = f"{repo_id}: failed\n {traceback.format_exc()}"
- with open(logfile, "a") as file:
- file.write(status + "\n")
- continue
-
-
-if __name__ == "__main__":
- batch_convert()
diff --git a/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py b/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py
deleted file mode 100644
index cddfc4c18..000000000
--- a/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py
+++ /dev/null
@@ -1,687 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
-2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
-for each of the task performed in the dataset. This will allow to easily train models with task-conditioning.
-
-We support 3 different scenarios for these tasks (see instructions below):
- 1. Single task dataset: all episodes of your dataset have the same single task.
- 2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
- one episode to the next.
- 3. Multi task episodes: episodes of your dataset may each contain several different tasks.
-
-
-Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
-'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
-recorded with. For now, only Aloha/Koch type robots are supported with this option.
-
-
-# 1. Single task dataset
-If your dataset contains a single task, you can simply provide it directly via the CLI with the
-'--single-task' option.
-
-Examples:
-
-```bash
-python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
- --repo-id lerobot/aloha_sim_insertion_human_image \
- --single-task "Insert the peg into the socket." \
- --robot-config lerobot/configs/robot/aloha.yaml \
- --local-dir data
-```
-
-```bash
-python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
- --repo-id aliberts/koch_tutorial \
- --single-task "Pick the Lego block and drop it in the box on the right." \
- --robot-config lerobot/configs/robot/koch.yaml \
- --local-dir data
-```
-
-
-# 2. Single task episodes
-If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
-
-- If your dataset already contains a language instruction column in its parquet file, you can simply provide
- this column's name with the '--tasks-col' arg.
-
- Example:
-
- ```bash
- python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
- --repo-id lerobot/stanford_kuka_multimodal_dataset \
- --tasks-col "language_instruction" \
- --local-dir data
- ```
-
-- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
- '--tasks-path' arg. This file should have the following structure where keys correspond to each
- episode_index in the dataset, and values are the language instruction for that episode.
-
- Example:
-
- ```json
- {
- "0": "Do something",
- "1": "Do something else",
- "2": "Do something",
- "3": "Go there",
- ...
- }
- ```
-
-# 3. Multi task episodes
-If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
-parquet file, and you must provide this column's name with the '--tasks-col' arg.
-
-Example:
-
-```bash
-python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \
- --repo-id lerobot/stanford_kuka_multimodal_dataset \
- --tasks-col "language_instruction" \
- --local-dir data
-```
-"""
-
-import argparse
-import contextlib
-import filecmp
-import json
-import logging
-import math
-import shutil
-import subprocess
-import tempfile
-from pathlib import Path
-
-import datasets
-import pyarrow.compute as pc
-import pyarrow.parquet as pq
-import torch
-from datasets import Dataset
-from huggingface_hub import HfApi
-from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
-from safetensors.torch import load_file
-
-from lerobot.datasets.utils import (
- DEFAULT_CHUNK_SIZE,
- DEFAULT_PARQUET_PATH,
- DEFAULT_VIDEO_PATH,
- EPISODES_PATH,
- INFO_PATH,
- STATS_PATH,
- TASKS_PATH,
- create_branch,
- create_lerobot_dataset_card,
- flatten_dict,
- get_safe_version,
- load_json,
- unflatten_dict,
- write_json,
- write_jsonlines,
-)
-from lerobot.datasets.video_utils import (
- VideoFrame, # noqa: F401
- get_image_pixel_channels,
- get_video_info,
-)
-from lerobot.robots import RobotConfig
-
-V16 = "v1.6"
-V20 = "v2.0"
-
-GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
-V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
-V1_INFO_PATH = "meta_data/info.json"
-V1_STATS_PATH = "meta_data/stats.safetensors"
-
-
-def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]:
- if robot_cfg.type in ["aloha", "koch"]:
- state_names = [
- f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor
- for arm in robot_cfg.follower_arms
- for motor in robot_cfg.follower_arms[arm].motors
- ]
- action_names = [
- # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
- f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor
- for arm in robot_cfg.leader_arms
- for motor in robot_cfg.leader_arms[arm].motors
- ]
- # elif robot_cfg["robot_type"] == "stretch3": TODO
- else:
- raise NotImplementedError(
- "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
- )
-
- return {
- "robot_type": robot_cfg.type,
- "names": {
- "observation.state": state_names,
- "observation.effort": state_names,
- "action": action_names,
- },
- }
-
-
-def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
- safetensor_path = v1_dir / V1_STATS_PATH
- stats = load_file(safetensor_path)
- serialized_stats = {key: value.tolist() for key, value in stats.items()}
- serialized_stats = unflatten_dict(serialized_stats)
-
- json_path = v2_dir / STATS_PATH
- json_path.parent.mkdir(exist_ok=True, parents=True)
- with open(json_path, "w") as f:
- json.dump(serialized_stats, f, indent=4)
-
- # Sanity check
- with open(json_path) as f:
- stats_json = json.load(f)
-
- stats_json = flatten_dict(stats_json)
- stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
- for key in stats:
- torch.testing.assert_close(stats_json[key], stats[key])
-
-
-def get_features_from_hf_dataset(
- dataset: Dataset, robot_config: RobotConfig | None = None
-) -> dict[str, list]:
- robot_config = parse_robot_config(robot_config)
- features = {}
- for key, ft in dataset.features.items():
- if isinstance(ft, datasets.Value):
- dtype = ft.dtype
- shape = (1,)
- names = None
- if isinstance(ft, datasets.Sequence):
- assert isinstance(ft.feature, datasets.Value)
- dtype = ft.feature.dtype
- shape = (ft.length,)
- motor_names = (
- robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
- )
- assert len(motor_names) == shape[0]
- names = {"motors": motor_names}
- elif isinstance(ft, datasets.Image):
- dtype = "image"
- image = dataset[0][key] # Assuming first row
- channels = get_image_pixel_channels(image)
- shape = (image.height, image.width, channels)
- names = ["height", "width", "channels"]
- elif ft._type == "VideoFrame":
- dtype = "video"
- shape = None # Add shape later
- names = ["height", "width", "channels"]
-
- features[key] = {
- "dtype": dtype,
- "shape": shape,
- "names": names,
- }
-
- return features
-
-
-def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
- df = dataset.to_pandas()
- tasks = list(set(tasks_by_episodes.values()))
- tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
- episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
- df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
-
- features = dataset.features
- features["task_index"] = datasets.Value(dtype="int64")
- dataset = Dataset.from_pandas(df, features=features, split="train")
- return dataset, tasks
-
-
-def add_task_index_from_tasks_col(
- dataset: Dataset, tasks_col: str
-) -> tuple[Dataset, dict[str, list[str]], list[str]]:
- df = dataset.to_pandas()
-
- # HACK: This is to clean some of the instructions in our version of Open X datasets
- prefix_to_clean = "tf.Tensor(b'"
- suffix_to_clean = "', shape=(), dtype=string)"
- df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
-
- # Create task_index col
- tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
- tasks = df[tasks_col].unique().tolist()
- tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
- df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
-
- # Build the dataset back from df
- features = dataset.features
- features["task_index"] = datasets.Value(dtype="int64")
- dataset = Dataset.from_pandas(df, features=features, split="train")
- dataset = dataset.remove_columns(tasks_col)
-
- return dataset, tasks, tasks_by_episode
-
-
-def split_parquet_by_episodes(
- dataset: Dataset,
- total_episodes: int,
- total_chunks: int,
- output_dir: Path,
-) -> list:
- table = dataset.data.table
- episode_lengths = []
- for ep_chunk in range(total_chunks):
- ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
- ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
- chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
- (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
- for ep_idx in range(ep_chunk_start, ep_chunk_end):
- ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
- episode_lengths.insert(ep_idx, len(ep_table))
- output_file = output_dir / DEFAULT_PARQUET_PATH.format(
- episode_chunk=ep_chunk, episode_index=ep_idx
- )
- pq.write_table(ep_table, output_file)
-
- return episode_lengths
-
-
-def move_videos(
- repo_id: str,
- video_keys: list[str],
- total_episodes: int,
- total_chunks: int,
- work_dir: Path,
- clean_gittatributes: Path,
- branch: str = "main",
-) -> None:
- """
- HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
- commands to fetch git lfs video files references to move them into subdirectories without having to
- actually download them.
- """
- _lfs_clone(repo_id, work_dir, branch)
-
- videos_moved = False
- video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
- if len(video_files) == 0:
- video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
- videos_moved = True # Videos have already been moved
-
- assert len(video_files) == total_episodes * len(video_keys)
-
- lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
-
- current_gittatributes = work_dir / ".gitattributes"
- if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
- fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
-
- if lfs_untracked_videos:
- fix_lfs_video_files_tracking(work_dir, video_files)
-
- if videos_moved:
- return
-
- video_dirs = sorted(work_dir.glob("videos*/"))
- for ep_chunk in range(total_chunks):
- ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
- ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
- for vid_key in video_keys:
- chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
- episode_chunk=ep_chunk, video_key=vid_key
- )
- (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
-
- for ep_idx in range(ep_chunk_start, ep_chunk_end):
- target_path = DEFAULT_VIDEO_PATH.format(
- episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
- )
- video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
- if len(video_dirs) == 1:
- video_path = video_dirs[0] / video_file
- else:
- for dir in video_dirs:
- if (dir / video_file).is_file():
- video_path = dir / video_file
- break
-
- video_path.rename(work_dir / target_path)
-
- commit_message = "Move video files into chunk subdirectories"
- subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
- subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
- subprocess.run(["git", "push"], cwd=work_dir, check=True)
-
-
-def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
- """
- HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
- there's no other option than to download the actual files and reupload them with lfs tracking.
- """
- for i in range(0, len(lfs_untracked_videos), 100):
- files = lfs_untracked_videos[i : i + 100]
- try:
- subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
- except subprocess.CalledProcessError as e:
- print("git rm --cached ERROR:")
- print(e.stderr)
- subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
-
- commit_message = "Track video files with git lfs"
- subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
- subprocess.run(["git", "push"], cwd=work_dir, check=True)
-
-
-def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
- shutil.copyfile(clean_gittatributes, current_gittatributes)
- subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
- subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
- subprocess.run(["git", "push"], cwd=work_dir, check=True)
-
-
-def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
- subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
- repo_url = f"https://huggingface.co/datasets/{repo_id}"
- env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
- subprocess.run(
- ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
- check=True,
- env=env,
- )
-
-
-def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
- lfs_tracked_files = subprocess.run(
- ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
- )
- lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
- return [f for f in video_files if f not in lfs_tracked_files]
-
-
-def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
- # Assumes first episode
- video_files = [
- DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
- for vid_key in video_keys
- ]
- hub_api = HfApi()
- hub_api.snapshot_download(
- repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
- )
- videos_info_dict = {}
- for vid_key, vid_path in zip(video_keys, video_files, strict=True):
- videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
-
- return videos_info_dict
-
-
-def convert_dataset(
- repo_id: str,
- local_dir: Path,
- single_task: str | None = None,
- tasks_path: Path | None = None,
- tasks_col: Path | None = None,
- robot_config: RobotConfig | None = None,
- test_branch: str | None = None,
- **card_kwargs,
-):
- v1 = get_safe_version(repo_id, V16)
- v1x_dir = local_dir / V16 / repo_id
- v20_dir = local_dir / V20 / repo_id
- v1x_dir.mkdir(parents=True, exist_ok=True)
- v20_dir.mkdir(parents=True, exist_ok=True)
-
- hub_api = HfApi()
- hub_api.snapshot_download(
- repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
- )
- branch = "main"
- if test_branch:
- branch = test_branch
- create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
-
- metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
- dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
- features = get_features_from_hf_dataset(dataset, robot_config)
- video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
-
- if single_task and "language_instruction" in dataset.column_names:
- logging.warning(
- "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
- )
- single_task = None
- tasks_col = "language_instruction"
-
- # Episodes & chunks
- episode_indices = sorted(dataset.unique("episode_index"))
- total_episodes = len(episode_indices)
- assert episode_indices == list(range(total_episodes))
- total_videos = total_episodes * len(video_keys)
- total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
- if total_episodes % DEFAULT_CHUNK_SIZE != 0:
- total_chunks += 1
-
- # Tasks
- if single_task:
- tasks_by_episodes = dict.fromkeys(episode_indices, single_task)
- dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
- tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
- elif tasks_path:
- tasks_by_episodes = load_json(tasks_path)
- tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
- dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
- tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
- elif tasks_col:
- dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
- else:
- raise ValueError
-
- assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
- tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
- write_jsonlines(tasks, v20_dir / TASKS_PATH)
- features["task_index"] = {
- "dtype": "int64",
- "shape": (1,),
- "names": None,
- }
-
- # Videos
- if video_keys:
- assert metadata_v1.get("video", False)
- dataset = dataset.remove_columns(video_keys)
- clean_gitattr = Path(
- hub_api.hf_hub_download(
- repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
- )
- ).absolute()
- with tempfile.TemporaryDirectory() as tmp_video_dir:
- move_videos(
- repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
- )
- videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
- for key in video_keys:
- features[key]["shape"] = (
- videos_info[key].pop("video.height"),
- videos_info[key].pop("video.width"),
- videos_info[key].pop("video.channels"),
- )
- features[key]["video_info"] = videos_info[key]
- assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
- if "encoding" in metadata_v1:
- assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
- else:
- assert metadata_v1.get("video", 0) == 0
- videos_info = None
-
- # Split data into 1 parquet file by episode
- episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
-
- if robot_config is not None:
- robot_type = robot_config.type
- repo_tags = [robot_type]
- else:
- robot_type = "unknown"
- repo_tags = None
-
- # Episodes
- episodes = [
- {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
- for ep_idx in episode_indices
- ]
- write_jsonlines(episodes, v20_dir / EPISODES_PATH)
-
- # Assemble metadata v2.0
- metadata_v2_0 = {
- "codebase_version": V20,
- "robot_type": robot_type,
- "total_episodes": total_episodes,
- "total_frames": len(dataset),
- "total_tasks": len(tasks),
- "total_videos": total_videos,
- "total_chunks": total_chunks,
- "chunks_size": DEFAULT_CHUNK_SIZE,
- "fps": metadata_v1["fps"],
- "splits": {"train": f"0:{total_episodes}"},
- "data_path": DEFAULT_PARQUET_PATH,
- "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
- "features": features,
- }
- write_json(metadata_v2_0, v20_dir / INFO_PATH)
- convert_stats_to_json(v1x_dir, v20_dir)
- card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
-
- with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
- hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
-
- with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
- hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
-
- with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
- hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
-
- hub_api.upload_folder(
- repo_id=repo_id,
- path_in_repo="data",
- folder_path=v20_dir / "data",
- repo_type="dataset",
- revision=branch,
- )
- hub_api.upload_folder(
- repo_id=repo_id,
- path_in_repo="meta",
- folder_path=v20_dir / "meta",
- repo_type="dataset",
- revision=branch,
- )
-
- card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
-
- if not test_branch:
- create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
-
-
-def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
- if robot_type == "aloha":
- raise NotImplementedError # TODO
-
- elif robot_type == "koch_follower":
- from lerobot.robots.koch_follower import KochFollowerConfig
-
- return KochFollowerConfig(**kwargs)
- elif robot_type == "so100_follower":
- from lerobot.robots.so100_follower import SO100FollowerConfig
-
- return SO100FollowerConfig(**kwargs)
- elif robot_type == "stretch":
- from lerobot.robots.stretch3 import Stretch3RobotConfig
-
- return Stretch3RobotConfig(**kwargs)
- elif robot_type == "lekiwi":
- from lerobot.robots.lekiwi import LeKiwiConfig
-
- return LeKiwiConfig(**kwargs)
- else:
- raise ValueError(f"Robot type '{robot_type}' is not available.")
-
-
-def main():
- parser = argparse.ArgumentParser()
- task_args = parser.add_mutually_exclusive_group(required=True)
-
- parser.add_argument(
- "--repo-id",
- type=str,
- required=True,
- help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
- )
- task_args.add_argument(
- "--single-task",
- type=str,
- help="A short but accurate description of the single task performed in the dataset.",
- )
- task_args.add_argument(
- "--tasks-col",
- type=str,
- help="The name of the column containing language instructions",
- )
- task_args.add_argument(
- "--tasks-path",
- type=Path,
- help="The path to a .json file containing one language instruction for each episode_index",
- )
- parser.add_argument(
- "--robot",
- type=str,
- default=None,
- help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)",
- )
- parser.add_argument(
- "--local-dir",
- type=Path,
- default=None,
- help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
- )
- parser.add_argument(
- "--license",
- type=str,
- default="apache-2.0",
- help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
- )
- parser.add_argument(
- "--test-branch",
- type=str,
- default=None,
- help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
- )
-
- args = parser.parse_args()
- if not args.local_dir:
- args.local_dir = Path("/tmp/lerobot_dataset_v2")
-
- if args.robot is not None:
- robot_config = make_robot_config(args.robot)
-
- del args.robot
-
- convert_dataset(**vars(args), robot_config=robot_config)
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/lerobot/datasets/v21/_remove_language_instruction.py b/src/lerobot/datasets/v21/_remove_language_instruction.py
deleted file mode 100644
index 1f1cb1855..000000000
--- a/src/lerobot/datasets/v21/_remove_language_instruction.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import traceback
-from pathlib import Path
-
-from datasets import get_dataset_config_info
-from huggingface_hub import HfApi
-
-from lerobot import available_datasets
-from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.datasets.utils import INFO_PATH, write_info
-from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
-
-LOCAL_DIR = Path("data/")
-
-hub_api = HfApi()
-
-
-def fix_dataset(repo_id: str) -> str:
- if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
- return f"{repo_id}: skipped (not in {V20})."
-
- dataset_info = get_dataset_config_info(repo_id, "default")
- with SuppressWarnings():
- lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
-
- meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
- parquet_features = set(dataset_info.features)
-
- diff_parquet_meta = parquet_features - meta_features
- diff_meta_parquet = meta_features - parquet_features
-
- if diff_parquet_meta:
- raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
-
- if not diff_meta_parquet:
- return f"{repo_id}: skipped (no diff)"
-
- if diff_meta_parquet:
- logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
- assert diff_meta_parquet == {"language_instruction"}
- lerobot_metadata.features.pop("language_instruction")
- write_info(lerobot_metadata.info, lerobot_metadata.root)
- commit_info = hub_api.upload_file(
- path_or_fileobj=lerobot_metadata.root / INFO_PATH,
- path_in_repo=INFO_PATH,
- repo_id=repo_id,
- repo_type="dataset",
- revision=V20,
- commit_message="Remove 'language_instruction'",
- create_pr=True,
- )
- return f"{repo_id}: success - PR: {commit_info.pr_url}"
-
-
-def batch_fix():
- status = {}
- LOCAL_DIR.mkdir(parents=True, exist_ok=True)
- logfile = LOCAL_DIR / "fix_features_v20.txt"
- for num, repo_id in enumerate(available_datasets):
- print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
- print("---------------------------------------------------------")
- try:
- status = fix_dataset(repo_id)
- except Exception:
- status = f"{repo_id}: failed\n {traceback.format_exc()}"
-
- logging.info(status)
- with open(logfile, "a") as file:
- file.write(status + "\n")
-
-
-if __name__ == "__main__":
- batch_fix()
diff --git a/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py
deleted file mode 100644
index b4f1c36c4..000000000
--- a/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
-"""
-
-import traceback
-from pathlib import Path
-
-from huggingface_hub import HfApi
-
-from lerobot import available_datasets
-from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
-
-LOCAL_DIR = Path("data/")
-
-
-def batch_convert():
- status = {}
- LOCAL_DIR.mkdir(parents=True, exist_ok=True)
- logfile = LOCAL_DIR / "conversion_log_v21.txt"
- hub_api = HfApi()
- for num, repo_id in enumerate(available_datasets):
- print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
- print("---------------------------------------------------------")
- try:
- if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
- status = f"{repo_id}: success (already in {V21})."
- else:
- convert_dataset(repo_id)
- status = f"{repo_id}: success."
- except Exception:
- status = f"{repo_id}: failed\n {traceback.format_exc()}"
-
- with open(logfile, "a") as file:
- file.write(status + "\n")
-
-
-if __name__ == "__main__":
- batch_convert()
diff --git a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py
deleted file mode 100644
index 4ebc1086a..000000000
--- a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
-2.1. It will:
-
-- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
-- Check consistency between these new stats and the old ones.
-- Remove the deprecated `stats.json`.
-- Update codebase_version in `info.json`.
-- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
-
-Usage:
-
-```bash
-python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \
- --repo-id=aliberts/koch_tutorial
-```
-
-"""
-
-import argparse
-import logging
-
-from huggingface_hub import HfApi
-
-from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
-from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
-from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
-
-V20 = "v2.0"
-V21 = "v2.1"
-
-
-class SuppressWarnings:
- def __enter__(self):
- self.previous_level = logging.getLogger().getEffectiveLevel()
- logging.getLogger().setLevel(logging.ERROR)
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- logging.getLogger().setLevel(self.previous_level)
-
-
-def convert_dataset(
- repo_id: str,
- branch: str | None = None,
- num_workers: int = 4,
-):
- with SuppressWarnings():
- dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
-
- if (dataset.root / EPISODES_STATS_PATH).is_file():
- (dataset.root / EPISODES_STATS_PATH).unlink()
-
- convert_stats(dataset, num_workers=num_workers)
- ref_stats = load_stats(dataset.root)
- check_aggregate_stats(dataset, ref_stats)
-
- dataset.meta.info["codebase_version"] = CODEBASE_VERSION
- write_info(dataset.meta.info, dataset.root)
-
- dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
-
- # delete old stats.json file
- if (dataset.root / STATS_PATH).is_file:
- (dataset.root / STATS_PATH).unlink()
-
- hub_api = HfApi()
- if hub_api.file_exists(
- repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
- ):
- hub_api.delete_file(
- path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
- )
-
- hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--repo-id",
- type=str,
- required=True,
- help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
- "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
- )
- parser.add_argument(
- "--branch",
- type=str,
- default=None,
- help="Repo branch to push your dataset. Defaults to the main branch.",
- )
- parser.add_argument(
- "--num-workers",
- type=int,
- default=4,
- help="Number of workers for parallelizing stats compute. Defaults to 4.",
- )
-
- args = parser.parse_args()
- convert_dataset(**vars(args))
diff --git a/src/lerobot/datasets/v21/convert_stats.py b/src/lerobot/datasets/v21/convert_stats.py
deleted file mode 100644
index 462781c15..000000000
--- a/src/lerobot/datasets/v21/convert_stats.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-import numpy as np
-from tqdm import tqdm
-
-from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import write_episode_stats
-
-
-def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
- ep_len = dataset.meta.episodes[episode_index]["length"]
- sampled_indices = sample_indices(ep_len)
- query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
- video_frames = dataset._query_videos(query_timestamps, episode_index)
- return video_frames[ft_key].numpy()
-
-
-def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
- ep_start_idx = dataset.episode_data_index["from"][ep_idx]
- ep_end_idx = dataset.episode_data_index["to"][ep_idx]
- ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
-
- ep_stats = {}
- for key, ft in dataset.features.items():
- if ft["dtype"] == "video":
- # We sample only for videos
- ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
- else:
- ep_ft_data = np.array(ep_data[key])
-
- axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
- keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
- ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
-
- if ft["dtype"] in ["image", "video"]: # remove batch dim
- ep_stats[key] = {
- k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
- }
-
- dataset.meta.episodes_stats[ep_idx] = ep_stats
-
-
-def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
- assert dataset.episodes is None
- print("Computing episodes stats")
- total_episodes = dataset.meta.total_episodes
- if num_workers > 0:
- with ThreadPoolExecutor(max_workers=num_workers) as executor:
- futures = {
- executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
- for ep_idx in range(total_episodes)
- }
- for future in tqdm(as_completed(futures), total=total_episodes):
- future.result()
- else:
- for ep_idx in tqdm(range(total_episodes)):
- convert_episode_stats(dataset, ep_idx)
-
- for ep_idx in tqdm(range(total_episodes)):
- write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
-
-
-def check_aggregate_stats(
- dataset: LeRobotDataset,
- reference_stats: dict[str, dict[str, np.ndarray]],
- video_rtol_atol: tuple[float] = (1e-2, 1e-2),
- default_rtol_atol: tuple[float] = (5e-6, 6e-5),
-):
- """Verifies that the aggregated stats from episodes_stats are close to reference stats."""
- agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
- for key, ft in dataset.features.items():
- # These values might need some fine-tuning
- if ft["dtype"] == "video":
- # to account for image sub-sampling
- rtol, atol = video_rtol_atol
- else:
- rtol, atol = default_rtol_atol
-
- for stat, val in agg_stats[key].items():
- if key in reference_stats and stat in reference_stats[key]:
- err_msg = f"feature='{key}' stats='{stat}'"
- np.testing.assert_allclose(
- val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
- )
diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
new file mode 100644
index 000000000..900a43a4f
--- /dev/null
+++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script augments existing LeRobot datasets with quantile statistics.
+
+Most datasets created before the quantile feature was added do not contain
+quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
+
+1. Loads an existing LeRobot dataset in v3.0 format
+2. Checks if it already contains quantile statistics
+3. If missing, computes quantile statistics for all features
+4. Updates the dataset metadata with the new quantile statistics
+
+Usage:
+
+```bash
+python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
+ --repo-id=lerobot/pusht \
+```
+"""
+
+import argparse
+import concurrent.futures
+import logging
+from pathlib import Path
+
+import numpy as np
+import torch
+from huggingface_hub import HfApi
+from requests import HTTPError
+from tqdm import tqdm
+
+from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.datasets.utils import write_stats
+from lerobot.utils.utils import init_logging
+
+
+def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
+ """Check if dataset statistics already contain quantile information.
+
+ Args:
+ stats: Dataset statistics dictionary
+
+ Returns:
+ True if quantile statistics are present, False otherwise
+ """
+ if quantile_list_keys is None:
+ quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
+
+ if stats is None:
+ return False
+
+ for feature_stats in stats.values():
+ if any(q_key in feature_stats for q_key in quantile_list_keys):
+ return True
+
+ return False
+
+
+def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
+ """Process a single episode and return its statistics.
+
+ Args:
+ dataset: The LeRobot dataset
+ episode_idx: Index of the episode to process
+
+ Returns:
+ Dictionary containing episode statistics
+ """
+ logging.info(f"Computing stats for episode {episode_idx}")
+
+ start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
+ end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
+
+ collected_data: dict[str, list] = {}
+ for idx in range(start_idx, end_idx):
+ item = dataset[idx]
+ for key, value in item.items():
+ if key not in dataset.features:
+ continue
+
+ if key not in collected_data:
+ collected_data[key] = []
+ collected_data[key].append(value)
+
+ ep_stats = {}
+ for key, data_list in collected_data.items():
+ if dataset.features[key]["dtype"] == "string":
+ continue
+
+ data = torch.stack(data_list).cpu().numpy()
+ if dataset.features[key]["dtype"] in ["image", "video"]:
+ if data.dtype == np.uint8:
+ data = data.astype(np.float32) / 255.0
+
+ axes_to_reduce = (0, 2, 3)
+ keepdims = True
+ else:
+ axes_to_reduce = 0
+ keepdims = data.ndim == 1
+
+ ep_stats[key] = get_feature_stats(
+ data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
+ )
+
+ if dataset.features[key]["dtype"] in ["image", "video"]:
+ ep_stats[key] = {
+ k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+ }
+
+ return ep_stats
+
+
+def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
+ """Compute quantile statistics for all episodes in the dataset.
+
+ Args:
+ dataset: The LeRobot dataset to compute statistics for
+
+ Returns:
+ Dictionary containing aggregated statistics with quantiles
+
+ Note:
+ Video decoding operations are not thread-safe, so we process episodes sequentially
+ when video keys are present. For datasets without videos, we use parallel processing
+ with ThreadPoolExecutor for better performance.
+ """
+ logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
+
+ episode_stats_list = []
+ has_videos = len(dataset.meta.video_keys) > 0
+
+ if has_videos:
+ logging.info("Dataset contains video keys - using sequential processing for thread safety")
+ for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
+ ep_stats = process_single_episode(dataset, episode_idx)
+ episode_stats_list.append(ep_stats)
+ else:
+ logging.info("Dataset has no video keys - using parallel processing for better performance")
+ max_workers = min(dataset.num_episodes, 16)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ future_to_episode = {
+ executor.submit(process_single_episode, dataset, episode_idx): episode_idx
+ for episode_idx in range(dataset.num_episodes)
+ }
+
+ episode_results = {}
+ with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
+ for future in concurrent.futures.as_completed(future_to_episode):
+ episode_idx = future_to_episode[future]
+ ep_stats = future.result()
+ episode_results[episode_idx] = ep_stats
+ pbar.update(1)
+
+ for episode_idx in range(dataset.num_episodes):
+ if episode_idx in episode_results:
+ episode_stats_list.append(episode_results[episode_idx])
+
+ if not episode_stats_list:
+ raise ValueError("No episode data found for computing statistics")
+
+ logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
+ return aggregate_stats(episode_stats_list)
+
+
+def augment_dataset_with_quantile_stats(
+ repo_id: str,
+ root: str | Path | None = None,
+ overwrite: bool = False,
+) -> None:
+ """Augment a dataset with quantile statistics if they are missing.
+
+ Args:
+ repo_id: Repository ID of the dataset
+ root: Local root directory for the dataset
+ overwrite: Overwrite existing quantile statistics if they already exist
+ """
+ logging.info(f"Loading dataset: {repo_id}")
+ dataset = LeRobotDataset(
+ repo_id=repo_id,
+ root=root,
+ )
+
+ if not overwrite and has_quantile_stats(dataset.meta.stats):
+ logging.info("Dataset already contains quantile statistics. No action needed.")
+ return
+
+ logging.info("Dataset does not contain quantile statistics. Computing them now...")
+
+ new_stats = compute_quantile_stats_for_dataset(dataset)
+
+ logging.info("Updating dataset metadata with new quantile statistics")
+ dataset.meta.stats = new_stats
+
+ write_stats(new_stats, dataset.meta.root)
+
+ logging.info("Successfully updated dataset with quantile statistics")
+ dataset.push_to_hub()
+
+ hub_api = HfApi()
+ try:
+ hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+ except HTTPError as e:
+ logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
+ pass
+ hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
+
+
+def main():
+ """Main function to run the augmentation script."""
+ parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
+
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ required=True,
+ help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
+ )
+
+ parser.add_argument(
+ "--root",
+ type=str,
+ help="Local root directory for the dataset",
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ help="Overwrite existing quantile statistics if they already exist",
+ )
+
+ args = parser.parse_args()
+ root = Path(args.root) if args.root else None
+
+ init_logging()
+
+ augment_dataset_with_quantile_stats(
+ repo_id=args.repo_id,
+ root=root,
+ overwrite=args.overwrite,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
new file mode 100644
index 000000000..42ab2f642
--- /dev/null
+++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py
@@ -0,0 +1,571 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to
+3.0. It will:
+
+- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
+- Check consistency between these new stats and the old ones.
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
+- Push this new version to the hub on the 'main' branch and tags it with "v3.0".
+
+Usage:
+
+Convert a dataset from the hub:
+```bash
+python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
+ --repo-id=lerobot/pusht
+```
+
+Convert a local dataset (works in place):
+```bash
+python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
+ --repo-id=lerobot/pusht \
+ --root=/path/to/local/dataset/directory
+ --push-to-hub=false
+```
+
+"""
+
+import argparse
+import logging
+import shutil
+from pathlib import Path
+from typing import Any
+
+import jsonlines
+import pandas as pd
+import pyarrow as pa
+import tqdm
+from datasets import Dataset, Features, Image
+from huggingface_hub import HfApi, snapshot_download
+from requests import HTTPError
+
+from lerobot.datasets.compute_stats import aggregate_stats
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.datasets.utils import (
+ DEFAULT_CHUNK_SIZE,
+ DEFAULT_DATA_FILE_SIZE_IN_MB,
+ DEFAULT_DATA_PATH,
+ DEFAULT_VIDEO_FILE_SIZE_IN_MB,
+ DEFAULT_VIDEO_PATH,
+ LEGACY_EPISODES_PATH,
+ LEGACY_EPISODES_STATS_PATH,
+ LEGACY_TASKS_PATH,
+ cast_stats_to_numpy,
+ flatten_dict,
+ get_parquet_file_size_in_mb,
+ get_parquet_num_frames,
+ get_video_size_in_mb,
+ load_info,
+ update_chunk_file_indices,
+ write_episodes,
+ write_info,
+ write_stats,
+ write_tasks,
+)
+from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
+from lerobot.utils.constants import HF_LEROBOT_HOME
+from lerobot.utils.utils import init_logging
+
+V21 = "v2.1"
+V30 = "v3.0"
+
+"""
+-------------------------
+OLD
+data/chunk-000/episode_000000.parquet
+
+NEW
+data/chunk-000/file_000.parquet
+-------------------------
+OLD
+videos/chunk-000/CAMERA/episode_000000.mp4
+
+NEW
+videos/chunk-000/file_000.mp4
+-------------------------
+OLD
+episodes.jsonl
+{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
+
+NEW
+meta/episodes/chunk-000/episodes_000.parquet
+episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
+-------------------------
+OLD
+tasks.jsonl
+{"task_index": 1, "task": "Put the blue block in the green bowl"}
+
+NEW
+meta/tasks/chunk-000/file_000.parquet
+task_index | task
+-------------------------
+OLD
+episodes_stats.jsonl
+
+NEW
+meta/episodes_stats/chunk-000/file_000.parquet
+episode_index | mean | std | min | max
+-------------------------
+UPDATE
+meta/info.json
+-------------------------
+"""
+
+
+def load_jsonlines(fpath: Path) -> list[Any]:
+ with jsonlines.open(fpath, "r") as reader:
+ return list(reader)
+
+
+def legacy_load_episodes(local_dir: Path) -> dict:
+ episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
+ return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
+
+
+def legacy_load_episodes_stats(local_dir: Path) -> dict:
+ episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
+ return {
+ item["episode_index"]: cast_stats_to_numpy(item["stats"])
+ for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
+ }
+
+
+def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
+ tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
+ tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+ task_to_task_index = {task: task_index for task_index, task in tasks.items()}
+ return tasks, task_to_task_index
+
+
+def validate_local_dataset_version(local_path: Path) -> None:
+ """Validate that the local dataset has the expected v2.1 version."""
+ info = load_info(local_path)
+ dataset_version = info.get("codebase_version", "unknown")
+ if dataset_version != V21:
+ raise ValueError(
+ f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
+ f"This script is specifically for converting v2.1 datasets to v3.0."
+ )
+
+
+def convert_tasks(root, new_root):
+ logging.info(f"Converting tasks from {root} to {new_root}")
+ tasks, _ = legacy_load_tasks(root)
+ task_indices = tasks.keys()
+ task_strings = tasks.values()
+ df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
+ write_tasks(df_tasks, new_root)
+
+
+def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
+ # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
+ dataframes = [pd.read_parquet(file) for file in paths_to_cat]
+ # Concatenate all DataFrames along rows
+ concatenated_df = pd.concat(dataframes, ignore_index=True)
+
+ path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+ path.parent.mkdir(parents=True, exist_ok=True)
+
+ if len(image_keys) > 0:
+ schema = pa.Schema.from_pandas(concatenated_df)
+ features = Features.from_arrow_schema(schema)
+ for key in image_keys:
+ features[key] = Image()
+ schema = features.arrow_schema
+ else:
+ schema = None
+
+ concatenated_df.to_parquet(path, index=False, schema=schema)
+
+
+def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
+ data_dir = root / "data"
+ ep_paths = sorted(data_dir.glob("*/*.parquet"))
+
+ image_keys = get_image_keys(root)
+
+ ep_idx = 0
+ chunk_idx = 0
+ file_idx = 0
+ size_in_mb = 0
+ num_frames = 0
+ paths_to_cat = []
+ episodes_metadata = []
+
+ logging.info(f"Converting data files from {len(ep_paths)} episodes")
+
+ for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
+ ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
+ ep_num_frames = get_parquet_num_frames(ep_path)
+ ep_metadata = {
+ "episode_index": ep_idx,
+ "data/chunk_index": chunk_idx,
+ "data/file_index": file_idx,
+ "dataset_from_index": num_frames,
+ "dataset_to_index": num_frames + ep_num_frames,
+ }
+ size_in_mb += ep_size_in_mb
+ num_frames += ep_num_frames
+ episodes_metadata.append(ep_metadata)
+ ep_idx += 1
+
+ if size_in_mb < data_file_size_in_mb:
+ paths_to_cat.append(ep_path)
+ continue
+
+ if paths_to_cat:
+ concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
+
+ # Reset for the next file
+ size_in_mb = ep_size_in_mb
+ paths_to_cat = [ep_path]
+
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
+
+ # Write remaining data if any
+ if paths_to_cat:
+ concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
+
+ return episodes_metadata
+
+
+def get_video_keys(root):
+ info = load_info(root)
+ features = info["features"]
+ video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
+ return video_keys
+
+
+def get_image_keys(root):
+ info = load_info(root)
+ features = info["features"]
+ image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
+ return image_keys
+
+
+def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
+ logging.info(f"Converting videos from {root} to {new_root}")
+
+ video_keys = get_video_keys(root)
+ if len(video_keys) == 0:
+ return None
+
+ video_keys = sorted(video_keys)
+
+ eps_metadata_per_cam = []
+ for camera in video_keys:
+ eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb)
+ eps_metadata_per_cam.append(eps_metadata)
+
+ num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
+ if len(set(num_eps_per_cam)) != 1:
+ raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")
+
+ episods_metadata = []
+ num_cameras = len(video_keys)
+ num_episodes = num_eps_per_cam[0]
+ for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
+ # Sanity check
+ ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
+ ep_ids += [ep_idx]
+ if len(set(ep_ids)) != 1:
+ raise ValueError(f"All episode indices need to match ({ep_ids}).")
+
+ ep_dict = {}
+ for cam_idx in range(num_cameras):
+ ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
+ episods_metadata.append(ep_dict)
+
+ return episods_metadata
+
+
+def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int):
+ # Access old paths to mp4
+ videos_dir = root / "videos"
+ ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
+
+ ep_idx = 0
+ chunk_idx = 0
+ file_idx = 0
+ size_in_mb = 0
+ duration_in_s = 0.0
+ paths_to_cat = []
+ episodes_metadata = []
+
+ for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
+ ep_size_in_mb = get_video_size_in_mb(ep_path)
+ ep_duration_in_s = get_video_duration_in_s(ep_path)
+
+ # Check if adding this episode would exceed the limit
+ if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
+ # Size limit would be exceeded, save current accumulation WITHOUT this episode
+ concatenate_video_files(
+ paths_to_cat,
+ new_root
+ / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
+ )
+
+ # Update episodes metadata for the file we just saved
+ for i, _ in enumerate(paths_to_cat):
+ past_ep_idx = ep_idx - len(paths_to_cat) + i
+ episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
+ episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
+
+ # Move to next file and start fresh with current episode
+ chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
+ size_in_mb = 0
+ duration_in_s = 0.0
+ paths_to_cat = []
+
+ # Add current episode metadata
+ ep_metadata = {
+ "episode_index": ep_idx,
+ f"videos/{video_key}/chunk_index": chunk_idx, # Will be updated when file is saved
+ f"videos/{video_key}/file_index": file_idx, # Will be updated when file is saved
+ f"videos/{video_key}/from_timestamp": duration_in_s,
+ f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
+ }
+ episodes_metadata.append(ep_metadata)
+
+ # Add current episode to accumulation
+ paths_to_cat.append(ep_path)
+ size_in_mb += ep_size_in_mb
+ duration_in_s += ep_duration_in_s
+ ep_idx += 1
+
+ # Write remaining videos if any
+ if paths_to_cat:
+ concatenate_video_files(
+ paths_to_cat,
+ new_root
+ / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
+ )
+
+ # Update episodes metadata for the final file
+ for i, _ in enumerate(paths_to_cat):
+ past_ep_idx = ep_idx - len(paths_to_cat) + i
+ episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx
+ episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx
+
+ return episodes_metadata
+
+
+def generate_episode_metadata_dict(
+ episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
+):
+ num_episodes = len(episodes_metadata)
+ episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
+ episodes_stats_vals = list(episodes_stats.values())
+ episodes_stats_keys = list(episodes_stats.keys())
+
+ for i in range(num_episodes):
+ ep_legacy_metadata = episodes_legacy_metadata_vals[i]
+ ep_metadata = episodes_metadata[i]
+ ep_stats = episodes_stats_vals[i]
+
+ ep_ids_set = {
+ ep_legacy_metadata["episode_index"],
+ ep_metadata["episode_index"],
+ episodes_stats_keys[i],
+ }
+
+ if episodes_videos is None:
+ ep_video = {}
+ else:
+ ep_video = episodes_videos[i]
+ ep_ids_set.add(ep_video["episode_index"])
+
+ if len(ep_ids_set) != 1:
+ raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
+
+ ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
+ ep_dict["meta/episodes/chunk_index"] = 0
+ ep_dict["meta/episodes/file_index"] = 0
+ yield ep_dict
+
+
+def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
+ logging.info(f"Converting episodes metadata from {root} to {new_root}")
+
+ episodes_legacy_metadata = legacy_load_episodes(root)
+ episodes_stats = legacy_load_episodes_stats(root)
+
+ num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
+ if episodes_video_metadata is not None:
+ num_eps_set.add(len(episodes_video_metadata))
+
+ if len(num_eps_set) != 1:
+ raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
+
+ ds_episodes = Dataset.from_generator(
+ lambda: generate_episode_metadata_dict(
+ episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
+ )
+ )
+ write_episodes(ds_episodes, new_root)
+
+ stats = aggregate_stats(list(episodes_stats.values()))
+ write_stats(stats, new_root)
+
+
+def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
+ info = load_info(root)
+ info["codebase_version"] = V30
+ del info["total_chunks"]
+ del info["total_videos"]
+ info["data_files_size_in_mb"] = data_file_size_in_mb
+ info["video_files_size_in_mb"] = video_file_size_in_mb
+ info["data_path"] = DEFAULT_DATA_PATH
+ info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
+ info["fps"] = int(info["fps"])
+ logging.info(f"Converting info from {root} to {new_root}")
+ for key in info["features"]:
+ if info["features"][key]["dtype"] == "video":
+ # already has fps in video_info
+ continue
+ info["features"][key]["fps"] = info["fps"]
+ write_info(info, new_root)
+
+
+def convert_dataset(
+ repo_id: str,
+ branch: str | None = None,
+ data_file_size_in_mb: int | None = None,
+ video_file_size_in_mb: int | None = None,
+ root: str | Path | None = None,
+ push_to_hub: bool = True,
+ force_conversion: bool = False,
+):
+ if data_file_size_in_mb is None:
+ data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
+ if video_file_size_in_mb is None:
+ video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
+
+ # First check if the dataset already has a v3.0 version
+ if root is None and not force_conversion:
+ try:
+ print("Trying to download v3.0 version of the dataset from the hub...")
+ snapshot_download(repo_id, repo_type="dataset", revision=V30, local_dir=HF_LEROBOT_HOME / repo_id)
+ return
+ except Exception:
+ print("Dataset does not have an uploaded v3.0 version. Continuing with conversion.")
+
+ # Set root based on whether local dataset path is provided
+ use_local_dataset = False
+ root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
+ if root.exists():
+ validate_local_dataset_version(root)
+ use_local_dataset = True
+ print(f"Using local dataset at {root}")
+
+ old_root = root.parent / f"{root.name}_old"
+ new_root = root.parent / f"{root.name}_v30"
+
+ # Handle old_root cleanup if both old_root and root exist
+ if old_root.is_dir() and root.is_dir():
+ shutil.rmtree(str(root))
+ shutil.move(str(old_root), str(root))
+
+ if new_root.is_dir():
+ shutil.rmtree(new_root)
+
+ if not use_local_dataset:
+ snapshot_download(
+ repo_id,
+ repo_type="dataset",
+ revision=V21,
+ local_dir=root,
+ )
+
+ convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
+ convert_tasks(root, new_root)
+ episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
+ episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
+ convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
+
+ shutil.move(str(root), str(old_root))
+ shutil.move(str(new_root), str(root))
+
+ if push_to_hub:
+ hub_api = HfApi()
+ try:
+ hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+ except HTTPError as e:
+ print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
+ pass
+ hub_api.delete_files(
+ delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
+ repo_id=repo_id,
+ revision=branch,
+ repo_type="dataset",
+ )
+ hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+ LeRobotDataset(repo_id).push_to_hub()
+
+
+if __name__ == "__main__":
+ init_logging()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--repo-id",
+ type=str,
+ required=True,
+ help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+ "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+ )
+ parser.add_argument(
+ "--branch",
+ type=str,
+ default=None,
+ help="Repo branch to push your dataset. Defaults to the main branch.",
+ )
+ parser.add_argument(
+ "--data-file-size-in-mb",
+ type=int,
+ default=None,
+ help="File size in MB. Defaults to 100 for data and 500 for videos.",
+ )
+ parser.add_argument(
+ "--video-file-size-in-mb",
+ type=int,
+ default=None,
+ help="File size in MB. Defaults to 100 for data and 500 for videos.",
+ )
+ parser.add_argument(
+ "--root",
+ type=str,
+ default=None,
+ help="Local directory to use for downloading/writing the dataset.",
+ )
+ parser.add_argument(
+ "--push-to-hub",
+ type=lambda input: input.lower() == "true",
+ default=True,
+ help="Push the converted dataset to the hub.",
+ )
+ parser.add_argument(
+ "--force-conversion",
+ action="store_true",
+ help="Force conversion even if the dataset already has a v3.0 version.",
+ )
+
+ args = parser.parse_args()
+ convert_dataset(**vars(args))
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 3a77f36e4..1d4f07c76 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -16,12 +16,16 @@
import glob
import importlib
import logging
+import shutil
+import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
+from threading import Lock
from typing import Any, ClassVar
import av
+import fsspec
import pyarrow as pa
import torch
import torchvision
@@ -167,15 +171,68 @@ def decode_video_frames_torchvision(
return closest_frames
+class VideoDecoderCache:
+ """Thread-safe cache for video decoders to avoid expensive re-initialization."""
+
+ def __init__(self):
+ self._cache: dict[str, tuple[Any, Any]] = {}
+ self._lock = Lock()
+
+ def get_decoder(self, video_path: str):
+ """Get a cached decoder or create a new one."""
+ if importlib.util.find_spec("torchcodec"):
+ from torchcodec.decoders import VideoDecoder
+ else:
+ raise ImportError("torchcodec is required but not available.")
+
+ video_path = str(video_path)
+
+ with self._lock:
+ if video_path not in self._cache:
+ file_handle = fsspec.open(video_path).__enter__()
+ decoder = VideoDecoder(file_handle, seek_mode="approximate")
+ self._cache[video_path] = (decoder, file_handle)
+
+ return self._cache[video_path][0]
+
+ def clear(self):
+ """Clear the cache and close file handles."""
+ with self._lock:
+ for _, file_handle in self._cache.values():
+ file_handle.close()
+ self._cache.clear()
+
+ def size(self) -> int:
+ """Return the number of cached decoders."""
+ with self._lock:
+ return len(self._cache)
+
+
+class FrameTimestampError(ValueError):
+ """Helper error to indicate the retrieved timestamps exceed the queried ones"""
+
+ pass
+
+
+_default_decoder_cache = VideoDecoderCache()
+
+
def decode_video_frames_torchcodec(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
- device: str = "cpu",
log_loaded_timestamps: bool = False,
+ decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
+ Args:
+ video_path: Path to the video file.
+ timestamps: List of timestamps to extract frames.
+ tolerance_s: Allowed deviation in seconds for frame retrieval.
+ log_loaded_timestamps: Whether to log loaded timestamps.
+ decoder_cache: Optional decoder cache instance. Uses default if None.
+
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
@@ -184,27 +241,24 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
+ if decoder_cache is None:
+ decoder_cache = _default_decoder_cache
- if importlib.util.find_spec("torchcodec"):
- from torchcodec.decoders import VideoDecoder
- else:
- raise ImportError("torchcodec is required but not available.")
+ # Use cached decoder instead of creating new one each time
+ decoder = decoder_cache.get_decoder(str(video_path))
- # initialize video decoder
- decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
- loaded_frames = []
loaded_ts = []
+ loaded_frames = []
+
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
-
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
-
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
- for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
+ for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
@@ -235,10 +289,14 @@ def decode_video_frames_torchcodec(
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
- # convert to float32 in [0,1] range (channel first)
- closest_frames = closest_frames.type(torch.float32) / 255
+ # convert to float32 in [0,1] range
+ closest_frames = (closest_frames / 255.0).type(torch.float32)
+
+ if not len(timestamps) == len(closest_frames):
+ raise FrameTimestampError(
+ f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
+ )
- assert len(timestamps) == len(closest_frames)
return closest_frames
@@ -262,7 +320,11 @@ def encode_video_frames(
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
- video_path.parent.mkdir(parents=True, exist_ok=overwrite)
+ if video_path.exists() and not overwrite:
+ logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
+ return
+
+ video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
@@ -272,9 +334,9 @@ def encode_video_frames(
pix_fmt = "yuv420p"
# Get input frames
- template = "frame_" + ("[0-9]" * 6) + ".png"
+ template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
- glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
+ glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
@@ -299,7 +361,7 @@ def encode_video_frames(
# Set logging level
if log_level is not None:
- # "While less efficient, it is generally preferable to modify logging with Python’s logging"
+ # "While less efficient, it is generally preferable to modify logging with Python's logging"
logging.getLogger("libav").setLevel(log_level)
# Create and open output file (overwrite by default)
@@ -330,6 +392,86 @@ def encode_video_frames(
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
+def concatenate_video_files(
+ input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
+):
+ """
+ Concatenate multiple video files into a single video file using pyav.
+
+ This function takes a list of video input file paths and concatenates them into a single
+ output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
+ concatenation without re-encoding.
+
+ Args:
+ input_video_paths: Ordered list of input video file paths to concatenate.
+ output_video_path: Path to the output video file.
+ overwrite: Whether to overwrite the output video file if it already exists. Default is True.
+
+ Note:
+ - Creates a temporary directory for intermediate files that is cleaned up after use.
+ - Uses ffmpeg's concat demuxer which requires all input videos to have the same
+ codec, resolution, and frame rate for proper concatenation.
+ """
+
+ output_video_path = Path(output_video_path)
+
+ if output_video_path.exists() and not overwrite:
+ logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
+ return
+
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ if len(input_video_paths) == 0:
+ raise FileNotFoundError("No input video paths provided.")
+
+ # Create a temporary .ffconcat file to list the input video paths
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
+ tmp_concatenate_file.write("ffconcat version 1.0\n")
+ for input_path in input_video_paths:
+ tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
+ tmp_concatenate_file.flush()
+ tmp_concatenate_path = tmp_concatenate_file.name
+
+ # Create input and output containers
+ input_container = av.open(
+ tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
+ ) # safe = 0 allows absolute paths as well as relative paths
+
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
+ tmp_output_video_path = tmp_named_file.name
+
+ output_container = av.open(
+ tmp_output_video_path, mode="w", options={"movflags": "faststart"}
+ ) # faststart is to move the metadata to the beginning of the file to speed up loading
+
+ # Replicate input streams in output container
+ stream_map = {}
+ for input_stream in input_container.streams:
+ if input_stream.type in ("video", "audio", "subtitle"): # only copy compatible streams
+ stream_map[input_stream.index] = output_container.add_stream_from_template(
+ template=input_stream, opaque=True
+ )
+
+ # Demux + remux packets (no re-encode)
+ for packet in input_container.demux():
+ # Skip packets from un-mapped streams
+ if packet.stream.index not in stream_map:
+ continue
+
+ # Skip demux flushing packets
+ if packet.dts is None:
+ continue
+
+ output_stream = stream_map[packet.stream.index]
+ packet.stream = output_stream
+ output_container.mux(packet)
+
+ input_container.close()
+ output_container.close()
+ shutil.move(tmp_output_video_path, output_video_path)
+ Path(tmp_concatenate_path).unlink()
+
+
@dataclass
class VideoFrame:
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
@@ -440,14 +582,86 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
raise ValueError("Unknown format")
-def get_image_pixel_channels(image: Image):
- if image.mode == "L":
- return 1 # Grayscale
- elif image.mode == "LA":
- return 2 # Grayscale + Alpha
- elif image.mode == "RGB":
- return 3 # RGB
- elif image.mode == "RGBA":
- return 4 # RGBA
- else:
- raise ValueError("Unknown format")
+def get_video_duration_in_s(video_path: Path | str) -> float:
+ """
+ Get the duration of a video file in seconds using PyAV.
+
+ Args:
+ video_path: Path to the video file.
+
+ Returns:
+ Duration of the video in seconds.
+ """
+ with av.open(str(video_path)) as container:
+ # Get the first video stream
+ video_stream = container.streams.video[0]
+ # Calculate duration: stream.duration * stream.time_base gives duration in seconds
+ if video_stream.duration is not None:
+ duration = float(video_stream.duration * video_stream.time_base)
+ else:
+ # Fallback to container duration if stream duration is not available
+ duration = float(container.duration / av.time_base)
+ return duration
+
+
+class VideoEncodingManager:
+ """
+ Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
+
+ This manager handles:
+ - Batch encoding for any remaining episodes when recording interrupted
+ - Cleaning up temporary image files from interrupted episodes
+ - Removing empty image directories
+
+ Args:
+ dataset: The LeRobotDataset instance
+ """
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # Handle any remaining episodes that haven't been batch encoded
+ if self.dataset.episodes_since_last_encoding > 0:
+ if exc_type is not None:
+ logging.info("Exception occurred. Encoding remaining episodes before exit...")
+ else:
+ logging.info("Recording stopped. Encoding remaining episodes...")
+
+ start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
+ end_ep = self.dataset.num_episodes
+ logging.info(
+ f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
+ f"from episode {start_ep} to {end_ep - 1}"
+ )
+ self.dataset._batch_save_episode_video(start_ep, end_ep)
+
+ # Clean up episode images if recording was interrupted
+ if exc_type is not None:
+ interrupted_episode_index = self.dataset.num_episodes
+ for key in self.dataset.meta.video_keys:
+ img_dir = self.dataset._get_image_file_path(
+ episode_index=interrupted_episode_index, image_key=key, frame_index=0
+ ).parent
+ if img_dir.exists():
+ logging.debug(
+ f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
+ )
+ shutil.rmtree(img_dir)
+
+ # Clean up any remaining images directory if it's empty
+ img_dir = self.dataset.root / "images"
+ # Check for any remaining PNG files
+ png_files = list(img_dir.rglob("*.png"))
+ if len(png_files) == 0:
+ # Only remove the images directory if no PNG files remain
+ if img_dir.exists():
+ shutil.rmtree(img_dir)
+ logging.debug("Cleaned up empty images directory")
+ else:
+ logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
+
+ return False # Don't suppress the original exception
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index de969d618..0daaaf9fd 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -14,14 +14,14 @@
import abc
from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any
import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
@dataclass
@@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int = 30
features: dict[str, PolicyFeature] = field(default_factory=dict)
features_map: dict[str, str] = field(default_factory=dict)
+ max_parallel_tasks: int = 1
+ disable_env_checker: bool = True
@property
def type(self) -> str:
@@ -44,19 +46,19 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
- task: str = "AlohaInsertion-v0"
+ task: str | None = "AlohaInsertion-v0"
fps: int = 50
episode_length: int = 400
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
- "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
- "action": ACTION,
+ ACTION: ACTION,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
@@ -82,7 +84,7 @@ class AlohaEnv(EnvConfig):
@EnvConfig.register_subclass("pusht")
@dataclass
class PushtEnv(EnvConfig):
- task: str = "PushT-v0"
+ task: str | None = "PushT-v0"
fps: int = 10
episode_length: int = 300
obs_type: str = "pixels_agent_pos"
@@ -91,13 +93,13 @@ class PushtEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
- "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
- "action": ACTION,
+ ACTION: ACTION,
"agent_pos": OBS_STATE,
"environment_state": OBS_ENV_STATE,
"pixels": OBS_IMAGE,
@@ -124,7 +126,7 @@ class PushtEnv(EnvConfig):
@EnvConfig.register_subclass("xarm")
@dataclass
class XarmEnv(EnvConfig):
- task: str = "XarmLift-v0"
+ task: str | None = "XarmLift-v0"
fps: int = 15
episode_length: int = 200
obs_type: str = "pixels_agent_pos"
@@ -133,13 +135,13 @@ class XarmEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
- "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
- "action": ACTION,
+ ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
@@ -161,33 +163,69 @@ class XarmEnv(EnvConfig):
@dataclass
-class VideoRecordConfig:
- """Configuration for video recording in ManiSkill environments."""
-
- enabled: bool = False
- record_dir: str = "videos"
- trajectory_name: str = "trajectory"
+class ImagePreprocessingConfig:
+ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
+ resize_size: tuple[int, int] | None = None
@dataclass
-class EnvTransformConfig:
- """Configuration for environment wrappers."""
+class RewardClassifierConfig:
+ """Configuration for reward classification."""
+
+ pretrained_path: str | None = None
+ success_threshold: float = 0.5
+ success_reward: float = 1.0
+
+
+@dataclass
+class InverseKinematicsConfig:
+ """Configuration for inverse kinematics processing."""
+
+ urdf_path: str | None = None
+ target_frame_name: str | None = None
+ end_effector_bounds: dict[str, list[float]] | None = None
+ end_effector_step_sizes: dict[str, float] | None = None
+
+
+@dataclass
+class ObservationConfig:
+ """Configuration for observation processing."""
- # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
- control_mode: str = "gamepad"
- display_cameras: bool = False
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
- add_ee_pose_to_observation: bool = False
- crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
- resize_size: Optional[tuple[int, int]] = None
- control_time_s: float = 20.0
- fixed_reset_joint_positions: Optional[Any] = None
- reset_time_s: float = 5.0
+ display_cameras: bool = False
+
+
+@dataclass
+class GripperConfig:
+ """Configuration for gripper control and penalties."""
+
use_gripper: bool = True
- gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
- gripper_penalty_in_reward: bool = False
+
+
+@dataclass
+class ResetConfig:
+ """Configuration for environment reset behavior."""
+
+ fixed_reset_joint_positions: Any | None = None
+ reset_time_s: float = 5.0
+ control_time_s: float = 20.0
+ terminate_on_success: bool = True
+
+
+@dataclass
+class HILSerlProcessorConfig:
+ """Configuration for environment processing pipeline."""
+
+ control_mode: str = "gamepad"
+ observation: ObservationConfig | None = None
+ image_preprocessing: ImagePreprocessingConfig | None = None
+ gripper: GripperConfig | None = None
+ reset: ResetConfig | None = None
+ inverse_kinematics: InverseKinematicsConfig | None = None
+ reward_classifier: RewardClassifierConfig | None = None
+ max_gripper_pos: float | None = 100.0
@EnvConfig.register_subclass(name="gym_manipulator")
@@ -195,79 +233,64 @@ class EnvTransformConfig:
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
- robot: Optional[RobotConfig] = None
- teleop: Optional[TeleoperatorConfig] = None
- wrapper: Optional[EnvTransformConfig] = None
- fps: int = 10
- name: str = "real_robot"
- mode: str = None # Either "record", "replay", None
- repo_id: Optional[str] = None
- dataset_root: Optional[str] = None
- task: str = ""
- num_episodes: int = 10 # only for record mode
- episode: int = 0
- device: str = "cuda"
- push_to_hub: bool = True
- pretrained_policy_name_or_path: Optional[str] = None
- reward_classifier_pretrained_path: Optional[str] = None
- # For the reward classifier, to record more positive examples after a success
- number_of_steps_after_success: int = 0
+ robot: RobotConfig | None = None
+ teleop: TeleoperatorConfig | None = None
+ processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
+ name: str = "real_robot"
+
+ @property
def gym_kwargs(self) -> dict:
return {}
-@EnvConfig.register_subclass("hil")
+@EnvConfig.register_subclass("libero")
@dataclass
-class HILEnvConfig(EnvConfig):
- """Configuration for the HIL environment."""
-
- type: str = "hil"
- name: str = "PandaPickCube"
- task: str = "PandaPickCubeKeyboard-v0"
- use_viewer: bool = True
- gripper_penalty: float = 0.0
- use_gamepad: bool = True
- state_dim: int = 18
- action_dim: int = 4
- fps: int = 100
- episode_length: int = 100
- video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+class LiberoEnv(EnvConfig):
+ task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
+ fps: int = 30
+ episode_length: int = 520
+ obs_type: str = "pixels_agent_pos"
+ render_mode: str = "rgb_array"
+ camera_name: str = "agentview_image,robot0_eye_in_hand_image"
+ init_states: bool = True
+ camera_name_mapping: dict[str, str] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
- "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
- "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
- "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
- "action": ACTION,
- "observation.image": OBS_IMAGE,
- "observation.state": OBS_STATE,
+ ACTION: ACTION,
+ "agent_pos": OBS_STATE,
+ "pixels/agentview_image": f"{OBS_IMAGES}.image",
+ "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
)
- ################# args from hilserlrobotenv
- reward_classifier_pretrained_path: Optional[str] = None
- robot_config: Optional[RobotConfig] = None
- teleop_config: Optional[TeleoperatorConfig] = None
- wrapper: Optional[EnvTransformConfig] = None
- mode: str = None # Either "record", "replay", None
- repo_id: Optional[str] = None
- dataset_root: Optional[str] = None
- num_episodes: int = 10 # only for record mode
- episode: int = 0
- device: str = "cuda"
- push_to_hub: bool = True
- pretrained_policy_name_or_path: Optional[str] = None
- # For the reward classifier, to record more positive examples after a success
- number_of_steps_after_success: int = 0
- ############################
+
+ def __post_init__(self):
+ if self.obs_type == "pixels":
+ self.features["pixels/agentview_image"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(360, 360, 3)
+ )
+ self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(360, 360, 3)
+ )
+ elif self.obs_type == "pixels_agent_pos":
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
+ self.features["pixels/agentview_image"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(360, 360, 3)
+ )
+ self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(360, 360, 3)
+ )
+ else:
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
@property
def gym_kwargs(self) -> dict:
return {
- "use_viewer": self.use_viewer,
- "use_gamepad": self.use_gamepad,
- "gripper_penalty": self.gripper_penalty,
+ "obs_type": self.obs_type,
+ "render_mode": self.render_mode,
}
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index dc6d96d61..c27f01b65 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
-from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
+from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,13 +27,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
- elif env_type == "hil":
- return HILEnvConfig(**kwargs)
+ elif env_type == "libero":
+ return LiberoEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
-def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
+def make_env(
+ cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
+) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config.
Args:
@@ -47,13 +49,33 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
ModuleNotFoundError: If the requested env package is not installed
Returns:
- gym.vector.VectorEnv: The parallelized gym.env instance.
+ dict[str, dict[int, gym.vector.VectorEnv]]:
+ A mapping from suite name to indexed vectorized environments.
+ - For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id.
+ - For single-task environments: a single suite entry (cfg.type) with task_id=0.
+
"""
if n_envs < 1:
- raise ValueError("`n_envs must be at least 1")
+ raise ValueError("`n_envs` must be at least 1")
+
+ env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
+
+ if "libero" in cfg.type:
+ from lerobot.envs.libero import create_libero_envs
+
+ if cfg.task is None:
+ raise ValueError("LiberoEnv requires a task to be specified")
+
+ return create_libero_envs(
+ task=cfg.task,
+ n_envs=n_envs,
+ camera_name=cfg.camera_name,
+ init_states=cfg.init_states,
+ gym_kwargs=cfg.gym_kwargs,
+ env_cls=env_cls,
+ )
package_name = f"gym_{cfg.type}"
-
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
@@ -62,10 +84,11 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
gym_handle = f"{package_name}/{cfg.task}"
- # batched version of the env that returns an observation of shape (b, c)
- env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
- env = env_cls(
- [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
- )
+ def _make_one():
+ return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
- return env
+ vec = env_cls([_make_one for _ in range(n_envs)])
+
+ # normalize to {suite: {task_id: vec_env}} for consistency
+ suite_name = cfg.type # e.g., "pusht", "aloha"
+ return {suite_name: {0: vec}}
diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py
new file mode 100644
index 000000000..99ec6712f
--- /dev/null
+++ b/src/lerobot/envs/libero.py
@@ -0,0 +1,377 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import gymnasium as gym
+import numpy as np
+import torch
+from gymnasium import spaces
+from libero.libero import benchmark, get_libero_path
+from libero.libero.envs import OffScreenRenderEnv
+from robosuite.utils.transform_utils import quat2axisangle
+
+
+def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
+ """Normalize camera_name into a non-empty list of strings."""
+ if isinstance(camera_name, str):
+ cams = [c.strip() for c in camera_name.split(",") if c.strip()]
+ elif isinstance(camera_name, (list | tuple)):
+ cams = [str(c).strip() for c in camera_name if str(c).strip()]
+ else:
+ raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
+ if not cams:
+ raise ValueError("camera_name resolved to an empty list.")
+ return cams
+
+
+def _get_suite(name: str) -> benchmark.Benchmark:
+ """Instantiate a LIBERO suite by name with clear validation."""
+ bench = benchmark.get_benchmark_dict()
+ if name not in bench:
+ raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}")
+ suite = bench[name]()
+ if not getattr(suite, "tasks", None):
+ raise ValueError(f"Suite '{name}' has no tasks.")
+ return suite
+
+
+def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]:
+ """Validate/normalize task ids. If None → all tasks."""
+ if task_ids is None:
+ return list(range(total_tasks))
+ ids = sorted({int(t) for t in task_ids})
+ for t in ids:
+ if t < 0 or t >= total_tasks:
+ raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].")
+ return ids
+
+
+def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
+ init_states_path = (
+ Path(get_libero_path("init_states"))
+ / task_suite.tasks[i].problem_folder
+ / task_suite.tasks[i].init_states_file
+ )
+ init_states = torch.load(init_states_path, weights_only=False) # nosec B614
+ return init_states
+
+
+def get_libero_dummy_action():
+ """Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
+ return [0, 0, 0, 0, 0, 0, -1]
+
+
+OBS_STATE_DIM = 8
+ACTION_DIM = 7
+AGENT_POS_LOW = -1000.0
+AGENT_POS_HIGH = 1000.0
+ACTION_LOW = -1.0
+ACTION_HIGH = 1.0
+TASK_SUITE_MAX_STEPS: dict[str, int] = {
+ "libero_spatial": 280, # longest training demo has 193 steps
+ "libero_object": 280, # longest training demo has 254 steps
+ "libero_goal": 300, # longest training demo has 270 steps
+ "libero_10": 520, # longest training demo has 505 steps
+ "libero_90": 400, # longest training demo has 373 steps
+}
+
+
+class LiberoEnv(gym.Env):
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
+
+ def __init__(
+ self,
+ task_suite: Any,
+ task_id: int,
+ task_suite_name: str,
+ camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
+ obs_type: str = "pixels",
+ render_mode: str = "rgb_array",
+ observation_width: int = 256,
+ observation_height: int = 256,
+ visualization_width: int = 640,
+ visualization_height: int = 480,
+ init_states: bool = True,
+ episode_index: int = 0,
+ camera_name_mapping: dict[str, str] | None = None,
+ num_steps_wait: int = 10,
+ ):
+ super().__init__()
+ self.task_id = task_id
+ self.obs_type = obs_type
+ self.render_mode = render_mode
+ self.observation_width = observation_width
+ self.observation_height = observation_height
+ self.visualization_width = visualization_width
+ self.visualization_height = visualization_height
+ self.init_states = init_states
+ self.camera_name = _parse_camera_names(
+ camera_name
+ ) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
+
+ # Map raw camera names to "image1" and "image2".
+ # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`,
+ # following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`).
+ # This ensures the policy consistently receives observations in the
+ # expected format regardless of the original camera naming.
+ if camera_name_mapping is None:
+ camera_name_mapping = {
+ "agentview_image": "image",
+ "robot0_eye_in_hand_image": "image2",
+ }
+ self.camera_name_mapping = camera_name_mapping
+ self.num_steps_wait = num_steps_wait
+ self.episode_index = episode_index
+ # Load once and keep
+ self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
+ self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
+
+ self._env = self._make_envs_task(task_suite, self.task_id)
+ default_steps = 500
+ self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
+
+ images = {}
+ for cam in self.camera_name:
+ images[self.camera_name_mapping[cam]] = spaces.Box(
+ low=0,
+ high=255,
+ shape=(self.observation_height, self.observation_width, 3),
+ dtype=np.uint8,
+ )
+
+ if self.obs_type == "state":
+ raise NotImplementedError(
+ "The 'state' observation type is not supported in LiberoEnv. "
+ "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
+ )
+
+ elif self.obs_type == "pixels":
+ self.observation_space = spaces.Dict(
+ {
+ "pixels": spaces.Dict(images),
+ }
+ )
+ elif self.obs_type == "pixels_agent_pos":
+ self.observation_space = spaces.Dict(
+ {
+ "pixels": spaces.Dict(images),
+ "agent_pos": spaces.Box(
+ low=AGENT_POS_LOW,
+ high=AGENT_POS_HIGH,
+ shape=(OBS_STATE_DIM,),
+ dtype=np.float64,
+ ),
+ }
+ )
+
+ self.action_space = spaces.Box(
+ low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
+ )
+
+ def render(self):
+ raw_obs = self._env.env._get_observations()
+ image = self._format_raw_obs(raw_obs)["pixels"]["image"]
+ return image
+
+ def _make_envs_task(self, task_suite: Any, task_id: int = 0):
+ task = task_suite.get_task(task_id)
+ self.task = task.name
+ self.task_description = task.language
+ task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
+
+ env_args = {
+ "bddl_file_name": task_bddl_file,
+ "camera_heights": self.observation_height,
+ "camera_widths": self.observation_width,
+ }
+ env = OffScreenRenderEnv(**env_args)
+ env.reset()
+ return env
+
+ def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]:
+ images = {}
+ for camera_name in self.camera_name:
+ image = raw_obs[camera_name]
+ image = image[::-1, ::-1] # rotate 180 degrees
+ images[self.camera_name_mapping[camera_name]] = image
+ state = np.concatenate(
+ (
+ raw_obs["robot0_eef_pos"],
+ quat2axisangle(raw_obs["robot0_eef_quat"]),
+ raw_obs["robot0_gripper_qpos"],
+ )
+ )
+ agent_pos = state
+ if self.obs_type == "pixels":
+ return {"pixels": images.copy()}
+ if self.obs_type == "pixels_agent_pos":
+ return {
+ "pixels": images.copy(),
+ "agent_pos": agent_pos,
+ }
+ raise NotImplementedError(
+ f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
+ "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
+ )
+
+ def reset(self, seed=None, **kwargs):
+ super().reset(seed=seed)
+ self._env.seed(seed)
+ if self.init_states and self._init_states is not None:
+ self._env.set_init_state(self._init_states[self._init_state_id])
+ raw_obs = self._env.reset()
+
+ # After reset, objects may be unstable (slightly floating, intersecting, etc.).
+ # Step the simulator with a no-op action for a few frames so everything settles.
+ # Increasing this value can improve determinism and reproducibility across resets.
+ for _ in range(self.num_steps_wait):
+ raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
+ observation = self._format_raw_obs(raw_obs)
+ info = {"is_success": False}
+ return observation, info
+
+ def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
+ if action.ndim != 1:
+ raise ValueError(
+ f"Expected action to be 1-D (shape (action_dim,)), "
+ f"but got shape {action.shape} with ndim={action.ndim}"
+ )
+ raw_obs, reward, done, info = self._env.step(action)
+
+ is_success = self._env.check_success()
+ terminated = done or is_success
+ info["is_success"] = is_success
+
+ observation = self._format_raw_obs(raw_obs)
+ if done:
+ self.reset()
+ info.update(
+ {
+ "task": self.task,
+ "task_id": self.task_id,
+ "done": done,
+ "is_success": is_success,
+ }
+ )
+ truncated = False
+ return observation, reward, terminated, truncated, info
+
+ def close(self):
+ self._env.close()
+
+
+def _make_env_fns(
+ *,
+ suite,
+ suite_name: str,
+ task_id: int,
+ n_envs: int,
+ camera_names: list[str],
+ init_states: bool,
+ gym_kwargs: Mapping[str, Any],
+) -> list[Callable[[], LiberoEnv]]:
+ """Build n_envs factory callables for a single (suite, task_id)."""
+
+ def _make_env(episode_index: int, **kwargs) -> LiberoEnv:
+ local_kwargs = dict(kwargs)
+ return LiberoEnv(
+ task_suite=suite,
+ task_id=task_id,
+ task_suite_name=suite_name,
+ camera_name=camera_names,
+ init_states=init_states,
+ episode_index=episode_index,
+ **local_kwargs,
+ )
+
+ fns: list[Callable[[], LiberoEnv]] = []
+ for episode_index in range(n_envs):
+ fns.append(partial(_make_env, episode_index, **gym_kwargs))
+ return fns
+
+
+# ---- Main API ----------------------------------------------------------------
+
+
+def create_libero_envs(
+ task: str,
+ n_envs: int,
+ gym_kwargs: dict[str, Any] | None = None,
+ camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
+ init_states: bool = True,
+ env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
+) -> dict[str, dict[int, Any]]:
+ """
+ Create vectorized LIBERO environments with a consistent return shape.
+
+ Returns:
+ dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
+ Notes:
+ - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
+ - `task` can be a single suite or a comma-separated list of suites.
+ - You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite.
+ """
+ if env_cls is None or not callable(env_cls):
+ raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
+ if not isinstance(n_envs, int) or n_envs <= 0:
+ raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
+
+ gym_kwargs = dict(gym_kwargs or {})
+ task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks
+
+ camera_names = _parse_camera_names(camera_name)
+ suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
+ if not suite_names:
+ raise ValueError("`task` must contain at least one LIBERO suite name.")
+
+ print(
+ f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}"
+ )
+ if task_ids_filter is not None:
+ print(f"Restricting to task_ids={task_ids_filter}")
+
+ out: dict[str, dict[int, Any]] = defaultdict(dict)
+
+ for suite_name in suite_names:
+ suite = _get_suite(suite_name)
+ total = len(suite.tasks)
+ selected = _select_task_ids(total, task_ids_filter)
+
+ if not selected:
+ raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+
+ for tid in selected:
+ fns = _make_env_fns(
+ suite=suite,
+ suite_name=suite_name,
+ task_id=tid,
+ n_envs=n_envs,
+ camera_names=camera_names,
+ init_states=init_states,
+ gym_kwargs=gym_kwargs,
+ )
+ out[suite_name][tid] = env_cls(fns)
+ print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
+
+ # return plain dicts for predictability
+ return {suite: dict(task_map) for suite, task_map in out.items()}
diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py
index 00676a011..5584e0bff 100644
--- a/src/lerobot/envs/utils.py
+++ b/src/lerobot/envs/utils.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
+from collections.abc import Mapping, Sequence
+from functools import singledispatch
from typing import Any
import einops
@@ -24,6 +26,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape
@@ -39,44 +42,44 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
- imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+ imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
else:
- imgs = {"observation.image": observations["pixels"]}
+ imgs = {OBS_IMAGE: observations["pixels"]}
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
- img = torch.from_numpy(img)
+ img_tensor = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
- if img.ndim == 3:
- img = img.unsqueeze(0)
+ if img_tensor.ndim == 3:
+ img_tensor = img_tensor.unsqueeze(0)
# sanity check that images are channel last
- _, h, w, c = img.shape
- assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+ _, h, w, c = img_tensor.shape
+ assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}"
# sanity check that images are uint8
- assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+ assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}"
# convert to channel first of type float32 in range [0,1]
- img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
- img = img.type(torch.float32)
- img /= 255
+ img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
+ img_tensor = img_tensor.type(torch.float32)
+ img_tensor /= 255
- return_observations[imgkey] = img
+ return_observations[imgkey] = img_tensor
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
- return_observations["observation.environment_state"] = env_state
+ return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
- return_observations["observation.state"] = agent_pos
+ return_observations[OBS_STATE] = agent_pos
return return_observations
@@ -127,10 +130,68 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
- observation["task"] = env.call("task_description")
+ task_result = env.call("task_description")
+
+ if isinstance(task_result, tuple):
+ task_result = list(task_result)
+
+ if not isinstance(task_result, list):
+ raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
+ if not all(isinstance(item, str) for item in task_result):
+ raise TypeError("All items in task_description result must be strings")
+
+ observation["task"] = task_result
elif hasattr(env.envs[0], "task"):
- observation["task"] = env.call("task")
+ task_result = env.call("task")
+
+ if isinstance(task_result, tuple):
+ task_result = list(task_result)
+
+ if not isinstance(task_result, list):
+ raise TypeError(f"Expected task to return a list, got {type(task_result)}")
+ if not all(isinstance(item, str) for item in task_result):
+ raise TypeError("All items in task result must be strings")
+
+ observation["task"] = task_result
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]
return observation
+
+
+def _close_single_env(env: Any) -> None:
+ try:
+ env.close()
+ except Exception as exc:
+ print(f"Exception while closing env {env}: {exc}")
+
+
+@singledispatch
+def close_envs(obj: Any) -> None:
+ """Default: raise if the type is not recognized."""
+ raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}")
+
+
+@close_envs.register
+def _(env: Mapping) -> None:
+ for v in env.values():
+ if isinstance(v, Mapping):
+ close_envs(v)
+ elif hasattr(v, "close"):
+ _close_single_env(v)
+
+
+@close_envs.register
+def _(envs: Sequence) -> None:
+ if isinstance(envs, (str | bytes)):
+ return
+ for v in envs:
+ if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
+ close_envs(v)
+ elif hasattr(v, "close"):
+ _close_single_env(v)
+
+
+@close_envs.register
+def _(env: gym.Env) -> None:
+ _close_single_env(env)
diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py
index dfbfbaee8..850ef33d7 100644
--- a/src/lerobot/motors/__init__.py
+++ b/src/lerobot/motors/__init__.py
@@ -1 +1,17 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py
new file mode 100644
index 000000000..9832a1636
--- /dev/null
+++ b/src/lerobot/motors/calibration_gui.py
@@ -0,0 +1,401 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import os
+from dataclasses import dataclass
+
+os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"
+
+from lerobot.motors import MotorCalibration, MotorsBus
+
+BAR_LEN, BAR_THICKNESS = 450, 8
+HANDLE_R = 10
+BRACKET_W, BRACKET_H = 6, 14
+TRI_W, TRI_H = 12, 14
+
+BTN_W, BTN_H = 60, 22
+SAVE_W, SAVE_H = 80, 28
+LOAD_W = 80
+DD_W, DD_H = 160, 28
+
+TOP_GAP = 50
+PADDING_Y, TOP_OFFSET = 70, 60
+FONT_SIZE, FPS = 20, 60
+
+BG_COLOR = (30, 30, 30)
+BAR_RED, BAR_GREEN = (200, 60, 60), (60, 200, 60)
+HANDLE_COLOR, TEXT_COLOR = (240, 240, 240), (250, 250, 250)
+TICK_COLOR = (250, 220, 40)
+BTN_COLOR, BTN_COLOR_HL = (80, 80, 80), (110, 110, 110)
+DD_COLOR, DD_COLOR_HL = (70, 70, 70), (100, 100, 100)
+
+
+def dist(a, b):
+ return math.hypot(a[0] - b[0], a[1] - b[1])
+
+
+@dataclass
+class RangeValues:
+ min_v: int
+ pos_v: int
+ max_v: int
+
+
+class RangeSlider:
+ """One motor = one slider row"""
+
+ def __init__(self, motor, idx, res, calibration, present, label_pad, base_y):
+ import pygame
+
+ self.motor = motor
+ self.res = res
+ self.x0 = 40 + label_pad
+ self.x1 = self.x0 + BAR_LEN
+ self.y = base_y + idx * PADDING_Y
+
+ self.min_v = calibration.range_min
+ self.max_v = calibration.range_max
+ self.pos_v = max(self.min_v, min(present, self.max_v))
+
+ self.min_x = self._pos_from_val(self.min_v)
+ self.max_x = self._pos_from_val(self.max_v)
+ self.pos_x = self._pos_from_val(self.pos_v)
+
+ self.min_btn = pygame.Rect(self.x0 - BTN_W - 6, self.y - BTN_H // 2, BTN_W, BTN_H)
+ self.max_btn = pygame.Rect(self.x1 + 6, self.y - BTN_H // 2, BTN_W, BTN_H)
+
+ self.drag_min = self.drag_max = self.drag_pos = False
+ self.tick_val = present
+ self.font = pygame.font.Font(None, FONT_SIZE)
+
+ def _val_from_pos(self, x):
+ return round((x - self.x0) / BAR_LEN * self.res)
+
+ def _pos_from_val(self, v):
+ return self.x0 + (v / self.res) * BAR_LEN
+
+ def set_tick(self, v):
+ self.tick_val = max(0, min(v, self.res))
+
+ def _triangle_hit(self, pos):
+ import pygame
+
+ tri_top = self.y - BAR_THICKNESS // 2 - 2
+ return pygame.Rect(self.pos_x - TRI_W // 2, tri_top - TRI_H, TRI_W, TRI_H).collidepoint(pos)
+
+ def handle_event(self, e):
+ import pygame
+
+ if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
+ if self.min_btn.collidepoint(e.pos):
+ self.min_x, self.min_v = self.pos_x, self.pos_v
+ return
+ if self.max_btn.collidepoint(e.pos):
+ self.max_x, self.max_v = self.pos_x, self.pos_v
+ return
+ if dist(e.pos, (self.min_x, self.y)) <= HANDLE_R:
+ self.drag_min = True
+ elif dist(e.pos, (self.max_x, self.y)) <= HANDLE_R:
+ self.drag_max = True
+ elif self._triangle_hit(e.pos):
+ self.drag_pos = True
+
+ elif e.type == pygame.MOUSEBUTTONUP and e.button == 1:
+ self.drag_min = self.drag_max = self.drag_pos = False
+
+ elif e.type == pygame.MOUSEMOTION:
+ x = e.pos[0]
+ if self.drag_min:
+ self.min_x = max(self.x0, min(x, self.pos_x))
+ elif self.drag_max:
+ self.max_x = min(self.x1, max(x, self.pos_x))
+ elif self.drag_pos:
+ self.pos_x = max(self.min_x, min(x, self.max_x))
+
+ self.min_v = self._val_from_pos(self.min_x)
+ self.max_v = self._val_from_pos(self.max_x)
+ self.pos_v = self._val_from_pos(self.pos_x)
+
+ def _draw_button(self, surf, rect, text):
+ import pygame
+
+ clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR
+ pygame.draw.rect(surf, clr, rect, border_radius=4)
+ t = self.font.render(text, True, TEXT_COLOR)
+ surf.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2))
+
+ def draw(self, surf):
+ import pygame
+
+ # motor name above set-min button (right-aligned)
+ name_surf = self.font.render(self.motor, True, TEXT_COLOR)
+ surf.blit(
+ name_surf,
+ (self.min_btn.right - name_surf.get_width(), self.min_btn.y - name_surf.get_height() - 4),
+ )
+
+ # bar + active section
+ pygame.draw.rect(surf, BAR_RED, (self.x0, self.y - BAR_THICKNESS // 2, BAR_LEN, BAR_THICKNESS))
+ pygame.draw.rect(
+ surf, BAR_GREEN, (self.min_x, self.y - BAR_THICKNESS // 2, self.max_x - self.min_x, BAR_THICKNESS)
+ )
+
+ # tick
+ tick_x = self._pos_from_val(self.tick_val)
+ pygame.draw.line(
+ surf,
+ TICK_COLOR,
+ (tick_x, self.y - BAR_THICKNESS // 2 - 4),
+ (tick_x, self.y + BAR_THICKNESS // 2 + 4),
+ 2,
+ )
+
+ # brackets
+ for x, sign in ((self.min_x, +1), (self.max_x, -1)):
+ pygame.draw.line(
+ surf, HANDLE_COLOR, (x, self.y - BRACKET_H // 2), (x, self.y + BRACKET_H // 2), 2
+ )
+ pygame.draw.line(
+ surf,
+ HANDLE_COLOR,
+ (x, self.y - BRACKET_H // 2),
+ (x + sign * BRACKET_W, self.y - BRACKET_H // 2),
+ 2,
+ )
+ pygame.draw.line(
+ surf,
+ HANDLE_COLOR,
+ (x, self.y + BRACKET_H // 2),
+ (x + sign * BRACKET_W, self.y + BRACKET_H // 2),
+ 2,
+ )
+
+ # triangle ▼
+ tri_top = self.y - BAR_THICKNESS // 2 - 2
+ pygame.draw.polygon(
+ surf,
+ HANDLE_COLOR,
+ [
+ (self.pos_x, tri_top),
+ (self.pos_x - TRI_W // 2, tri_top - TRI_H),
+ (self.pos_x + TRI_W // 2, tri_top - TRI_H),
+ ],
+ )
+
+ # numeric labels
+ fh = self.font.get_height()
+ pos_y = tri_top - TRI_H - 4 - fh
+ txts = [
+ (self.min_v, self.min_x, self.y - BRACKET_H // 2 - 4 - fh),
+ (self.max_v, self.max_x, self.y - BRACKET_H // 2 - 4 - fh),
+ (self.pos_v, self.pos_x, pos_y),
+ ]
+ for v, x, y in txts:
+ s = self.font.render(str(v), True, TEXT_COLOR)
+ surf.blit(s, (x - s.get_width() // 2, y))
+
+ # buttons
+ self._draw_button(surf, self.min_btn, "set min")
+ self._draw_button(surf, self.max_btn, "set max")
+
+ # external
+ def values(self) -> RangeValues:
+ return RangeValues(self.min_v, self.pos_v, self.max_v)
+
+
+class RangeFinderGUI:
+ def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None):
+ import pygame
+
+ self.bus = bus
+ self.groups = groups if groups is not None else {"all": list(bus.motors)}
+ self.group_names = list(groups)
+ self.current_group = self.group_names[0]
+
+ if not bus.is_connected:
+ bus.connect()
+
+ self.calibration = bus.read_calibration()
+ self.res_table = bus.model_resolution_table
+ self.present_cache = {
+ m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
+ }
+
+ pygame.init()
+ self.font = pygame.font.Font(None, FONT_SIZE)
+
+ label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
+ self.label_pad = label_pad
+ width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
+ self.controls_bottom = 10 + SAVE_H
+ self.base_y = self.controls_bottom + TOP_GAP
+ height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
+
+ self.screen = pygame.display.set_mode((width, height))
+ pygame.display.set_caption("Motors range finder")
+
+ # ui rects
+ self.save_btn = pygame.Rect(width - SAVE_W - 10, 10, SAVE_W, SAVE_H)
+ self.load_btn = pygame.Rect(self.save_btn.left - LOAD_W - 10, 10, LOAD_W, SAVE_H)
+ self.dd_btn = pygame.Rect(width // 2 - DD_W // 2, 10, DD_W, DD_H)
+ self.dd_open = False # dropdown expanded?
+
+ self.clock = pygame.time.Clock()
+ self._build_sliders()
+ self._adjust_height()
+
+ def _adjust_height(self):
+ import pygame
+
+ motors = self.groups[self.current_group]
+ new_h = self.base_y + PADDING_Y * len(motors) + 40
+ if new_h != self.screen.get_height():
+ w = self.screen.get_width()
+ self.screen = pygame.display.set_mode((w, new_h))
+
+ def _build_sliders(self):
+ self.sliders: list[RangeSlider] = []
+ motors = self.groups[self.current_group]
+ for i, m in enumerate(motors):
+ self.sliders.append(
+ RangeSlider(
+ motor=m,
+ idx=i,
+ res=self.res_table[self.bus.motors[m].model] - 1,
+ calibration=self.calibration[m],
+ present=self.present_cache[m],
+ label_pad=self.label_pad,
+ base_y=self.base_y,
+ )
+ )
+
+ def _draw_dropdown(self):
+ import pygame
+
+ # collapsed box
+ hover = self.dd_btn.collidepoint(pygame.mouse.get_pos())
+ pygame.draw.rect(self.screen, DD_COLOR_HL if hover else DD_COLOR, self.dd_btn, border_radius=6)
+
+ txt = self.font.render(self.current_group, True, TEXT_COLOR)
+ self.screen.blit(
+ txt, (self.dd_btn.centerx - txt.get_width() // 2, self.dd_btn.centery - txt.get_height() // 2)
+ )
+
+ tri_w, tri_h = 12, 6
+ cx = self.dd_btn.right - 14
+ cy = self.dd_btn.centery + 1
+ pygame.draw.polygon(
+ self.screen,
+ TEXT_COLOR,
+ [(cx - tri_w // 2, cy - tri_h // 2), (cx + tri_w // 2, cy - tri_h // 2), (cx, cy + tri_h // 2)],
+ )
+
+ if not self.dd_open:
+ return
+
+ # expanded list
+ for i, name in enumerate(self.group_names):
+ item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
+ clr = DD_COLOR_HL if item_rect.collidepoint(pygame.mouse.get_pos()) else DD_COLOR
+ pygame.draw.rect(self.screen, clr, item_rect)
+ t = self.font.render(name, True, TEXT_COLOR)
+ self.screen.blit(
+ t, (item_rect.centerx - t.get_width() // 2, item_rect.centery - t.get_height() // 2)
+ )
+
+ def _handle_dropdown_event(self, e):
+ import pygame
+
+ if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
+ if self.dd_btn.collidepoint(e.pos):
+ self.dd_open = not self.dd_open
+ return True
+ if self.dd_open:
+ for i, name in enumerate(self.group_names):
+ item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H)
+ if item_rect.collidepoint(e.pos):
+ if name != self.current_group:
+ self.current_group = name
+ self._build_sliders()
+ self._adjust_height()
+ self.dd_open = False
+ return True
+ self.dd_open = False
+ return False
+
+ def _save_current(self):
+ for s in self.sliders:
+ self.calibration[s.motor].range_min = s.min_v
+ self.calibration[s.motor].range_max = s.max_v
+
+ with self.bus.torque_disabled():
+ self.bus.write_calibration(self.calibration)
+
+ def _load_current(self):
+ self.calibration = self.bus.read_calibration()
+ for s in self.sliders:
+ s.min_v = self.calibration[s.motor].range_min
+ s.max_v = self.calibration[s.motor].range_max
+ s.min_x = s._pos_from_val(s.min_v)
+ s.max_x = s._pos_from_val(s.max_v)
+
+ def run(self) -> dict[str, MotorCalibration]:
+ import pygame
+
+ while True:
+ for e in pygame.event.get():
+ if e.type == pygame.QUIT:
+ pygame.quit()
+ return self.calibration
+
+ if self._handle_dropdown_event(e):
+ continue
+
+ if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1:
+ if self.save_btn.collidepoint(e.pos):
+ self._save_current()
+ elif self.load_btn.collidepoint(e.pos):
+ self._load_current()
+
+ for s in self.sliders:
+ s.handle_event(e)
+
+ # live goal write while dragging
+ for s in self.sliders:
+ if s.drag_pos:
+ self.bus.write("Goal_Position", s.motor, s.pos_v, normalize=False)
+
+ # tick update
+ for s in self.sliders:
+ pos = self.bus.read("Present_Position", s.motor, normalize=False)
+ s.set_tick(pos)
+ self.present_cache[s.motor] = pos
+
+ # ─ drawing
+ self.screen.fill(BG_COLOR)
+ for s in self.sliders:
+ s.draw(self.screen)
+
+ self._draw_dropdown()
+
+ # load / save buttons
+ for rect, text in ((self.load_btn, "LOAD"), (self.save_btn, "SAVE")):
+ clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR
+ pygame.draw.rect(self.screen, clr, rect, border_radius=6)
+ t = self.font.render(text, True, TEXT_COLOR)
+ self.screen.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2))
+
+ pygame.display.flip()
+ self.clock.tick(FPS)
diff --git a/src/lerobot/motors/dynamixel/__init__.py b/src/lerobot/motors/dynamixel/__init__.py
index 3e414557e..425f8538a 100644
--- a/src/lerobot/motors/dynamixel/__init__.py
+++ b/src/lerobot/motors/dynamixel/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
from .tables import *
diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py
index d4f41643c..e1d4e0963 100644
--- a/src/lerobot/motors/dynamixel/dynamixel.py
+++ b/src/lerobot/motors/dynamixel/dynamixel.py
@@ -22,7 +22,7 @@ import logging
from copy import deepcopy
from enum import Enum
-from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement
+from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
@@ -162,11 +162,11 @@ class DynamixelMotorsBus(MotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
- def configure_motors(self) -> None:
+ def configure_motors(self, return_delay_time=0) -> None:
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for motor in self.motors:
- self.write("Return_Delay_Time", motor, 0)
+ self.write("Return_Delay_Time", motor, return_delay_time)
@property
def is_calibrated(self) -> bool:
@@ -190,13 +190,14 @@ class DynamixelMotorsBus(MotorsBus):
return calibration
- def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
+ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
- self.calibration = calibration_dict
+ if cache:
+ self.calibration = calibration_dict
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
diff --git a/src/lerobot/motors/dynamixel/tables.py b/src/lerobot/motors/dynamixel/tables.py
index 8b67bbf38..5417d8cee 100644
--- a/src/lerobot/motors/dynamixel/tables.py
+++ b/src/lerobot/motors/dynamixel/tables.py
@@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = {
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
+ "Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1],
+ "Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
diff --git a/src/lerobot/utils/encoding_utils.py b/src/lerobot/motors/encoding_utils.py
similarity index 100%
rename from src/lerobot/utils/encoding_utils.py
rename to src/lerobot/motors/encoding_utils.py
diff --git a/src/lerobot/motors/feetech/__init__.py b/src/lerobot/motors/feetech/__init__.py
index 911d1d19f..75da2d221 100644
--- a/src/lerobot/motors/feetech/__init__.py
+++ b/src/lerobot/motors/feetech/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
from .tables import *
diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py
index 7edf869a4..2ea57af12 100644
--- a/src/lerobot/motors/feetech/feetech.py
+++ b/src/lerobot/motors/feetech/feetech.py
@@ -17,7 +17,7 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
-from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
+from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
@@ -219,15 +219,15 @@ class FeetechMotorsBus(MotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
- def configure_motors(self) -> None:
+ def configure_motors(self, return_delay_time=0, maximum_acceleration=254, acceleration=254) -> None:
for motor in self.motors:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
- self.write("Return_Delay_Time", motor, 0)
+ self.write("Return_Delay_Time", motor, return_delay_time)
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
- # Note: this address is not in the official STS3215 Memory Table
- self.write("Maximum_Acceleration", motor, 254)
- self.write("Acceleration", motor, 254)
+ if self.protocol_version == 0:
+ self.write("Maximum_Acceleration", motor, maximum_acceleration)
+ self.write("Acceleration", motor, acceleration)
@property
def is_calibrated(self) -> bool:
@@ -270,14 +270,15 @@ class FeetechMotorsBus(MotorsBus):
return calibration
- def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
+ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
for motor, calibration in calibration_dict.items():
if self.protocol_version == 0:
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
- self.calibration = calibration_dict
+ if cache:
+ self.calibration = calibration_dict
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
diff --git a/src/lerobot/motors/feetech/tables.py b/src/lerobot/motors/feetech/tables.py
index 0a2f2659f..48814957f 100644
--- a/src/lerobot/motors/feetech/tables.py
+++ b/src/lerobot/motors/feetech/tables.py
@@ -189,7 +189,7 @@ MODEL_RESOLUTION = {
"scs_series": 1024,
"sts3215": 4096,
"sts3250": 4096,
- "sm8512bl": 65536,
+ "sm8512bl": 4096,
"scs0009": 1024,
}
diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py
index 7386bfb1c..17eaa8063 100644
--- a/src/lerobot/motors/motors_bus.py
+++ b/src/lerobot/motors/motors_bus.py
@@ -32,7 +32,7 @@ import serial
from deepdiff import DeepDiff
from tqdm import tqdm
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
@@ -99,12 +99,6 @@ class Motor:
norm_mode: MotorNormMode
-class JointOutOfRangeError(Exception):
- def __init__(self, message="Joint is out of range"):
- self.message = message
- super().__init__(self.message)
-
-
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
@@ -222,9 +216,9 @@ class MotorsBus(abc.ABC):
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
- python -m lerobot.find_port.py
+ lerobot-find-port.py
>>> Finding all available ports for the MotorsBus.
- >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
+ >>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"]
>>> Remove the usb cable from your MotorsBus and press Enter when done.
>>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
@@ -348,7 +342,7 @@ class MotorsBus(abc.ABC):
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
- if isinstance(values, (int, float)):
+ if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
@@ -446,7 +440,7 @@ class MotorsBus(abc.ABC):
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
- "\nTry running `python -m lerobot.find_port`\n"
+ "\nTry running `lerobot-find-port`\n"
) from e
@abc.abstractmethod
@@ -586,7 +580,7 @@ class MotorsBus(abc.ABC):
pass
@contextmanager
- def torque_disabled(self):
+ def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -596,11 +590,11 @@ class MotorsBus(abc.ABC):
... # Safe operations here
... pass
"""
- self.disable_torque()
+ self.disable_torque(motors)
try:
yield
finally:
- self.enable_torque()
+ self.enable_torque(motors)
def set_timeout(self, timeout_ms: int | None = None):
"""Change the packet timeout used by the SDK.
@@ -653,12 +647,13 @@ class MotorsBus(abc.ABC):
pass
@abc.abstractmethod
- def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
- """Write calibration parameters to the motors and cache them.
+ def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None:
+ """Write calibration parameters to the motors and optionally cache them.
Args:
calibration_dict (dict[str, MotorCalibration]): Calibration obtained from
:pymeth:`read_calibration` or crafted by the user.
+ cache (bool, optional): Save the calibration to :pyattr:`calibration`. Defaults to True.
"""
pass
@@ -674,7 +669,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
- elif isinstance(motors, (str, int)):
+ elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -702,7 +697,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
- elif isinstance(motors, (str, int)):
+ elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -738,7 +733,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
- elif isinstance(motors, (str, int)):
+ elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py
index ece4dc157..f2bd0df42 100644
--- a/src/lerobot/optim/optimizers.py
+++ b/src/lerobot/optim/optimizers.py
@@ -22,11 +22,11 @@ import draccus
import torch
from safetensors.torch import load_file, save_file
-from lerobot.constants import (
+from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
+from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
-from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.utils.io_utils import deserialize_json_into_object
diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py
index d08018175..55ee62e40 100644
--- a/src/lerobot/optim/schedulers.py
+++ b/src/lerobot/optim/schedulers.py
@@ -22,8 +22,8 @@ import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from lerobot.constants import SCHEDULER_STATE
from lerobot.datasets.utils import write_json
+from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.io_utils import deserialize_json_into_object
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 9cb0f6234..49f1e0f95 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -15,6 +15,18 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
+from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
+from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
+
+__all__ = [
+ "ACTConfig",
+ "DiffusionConfig",
+ "PI0Config",
+ "PI05Config",
+ "SmolVLAConfig",
+ "TDMPCConfig",
+ "VQBeTConfig",
+]
diff --git a/src/lerobot/policies/act/README.md b/src/lerobot/policies/act/README.md
new file mode 120000
index 000000000..046020098
--- /dev/null
+++ b/src/lerobot/policies/act/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_act_README.md
\ No newline at end of file
diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index ed911e9be..4d2890ba6 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -21,8 +21,8 @@ The majority of changes here involve removing unused code, unifying naming, and
import math
from collections import deque
+from collections.abc import Callable
from itertools import chain
-from typing import Callable
import einops
import numpy as np
@@ -33,10 +33,9 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
-from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTPolicy(PreTrainedPolicy):
@@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
- dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
- that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
self.model = ACT(config)
if config.temporal_ensemble_coeff is not None:
@@ -107,7 +95,7 @@ class ACTPolicy(PreTrainedPolicy):
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
- @torch.no_grad
+ @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -132,28 +120,24 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
- @torch.no_grad
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
- batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
- actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
- batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
- batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
@@ -216,7 +200,7 @@ class ACTTemporalEnsembler:
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
- avg /= exp_weights[:i+1].sum()
+ avg /= exp_weights[: i + 1].sum()
print("online", avg)
```
"""
@@ -410,25 +394,22 @@ class ACT(nn.Module):
latent dimension.
"""
if self.config.use_vae and self.training:
- assert "action" in batch, (
+ assert ACTION in batch, (
"actions must be provided when using the variational objective in training mode."
)
- if "observation.images" in batch:
- batch_size = batch["observation.images"][0].shape[0]
- else:
- batch_size = batch["observation.environment_state"].shape[0]
+ batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder.
- if self.config.use_vae and "action" in batch:
+ if self.config.use_vae and ACTION in batch and self.training:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
- robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+ robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
- action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
+ action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@@ -446,7 +427,7 @@ class ACT(nn.Module):
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
- device=batch["observation.state"].device,
+ device=batch[OBS_STATE].device,
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -470,7 +451,7 @@ class ACT(nn.Module):
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
- batch["observation.state"].device
+ batch[OBS_STATE].device
)
# Prepare transformer encoder inputs.
@@ -478,20 +459,16 @@ class ACT(nn.Module):
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.config.robot_state_feature:
- encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+ encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
# Environment state token.
if self.config.env_state_feature:
- encoder_in_tokens.append(
- self.encoder_env_state_input_proj(batch["observation.environment_state"])
- )
+ encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))
- # Camera observation features and positional embeddings.
if self.config.image_features:
- all_cam_features = []
- all_cam_pos_embeds = []
-
# For a list of images, the H and W may vary but H*W is constant.
- for img in batch["observation.images"]:
+ # NOTE: If modifying this section, verify on MPS devices that
+ # gradients remain stable (no explosions or NaNs).
+ for img in batch[OBS_IMAGES]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features)
@@ -500,11 +477,10 @@ class ACT(nn.Module):
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
- all_cam_features.append(cam_features)
- all_cam_pos_embeds.append(cam_pos_embed)
-
- encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
- encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
+ # Extend immediately instead of accumulating and concatenating
+ # Convert to list to extend properly
+ encoder_in_tokens.extend(list(cam_features))
+ encoder_in_pos_embed.extend(list(cam_pos_embed))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py
new file mode 100644
index 000000000..727b18cef
--- /dev/null
+++ b/src/lerobot/policies/act/processor_act.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from lerobot.policies.act.configuration_act import ACTConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_act_pre_post_processors(
+ config: ACTConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """Creates the pre- and post-processing pipelines for the ACT policy.
+
+ The pre-processing pipeline handles normalization, batching, and device placement for the model inputs.
+ The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU.
+
+ Args:
+ config (ACTConfig): The ACT policy configuration object.
+ dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset
+ statistics (e.g., mean and std) used for normalization. Defaults to None.
+
+ Returns:
+ tuple[PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction]]: A tuple containing the
+ pre-processor pipeline and the post-processor pipeline.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ device=config.device,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/diffusion/README.md b/src/lerobot/policies/diffusion/README.md
new file mode 120000
index 000000000..d332d79c8
--- /dev/null
+++ b/src/lerobot/policies/diffusion/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_diffusion_README.md
\ No newline at end of file
diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py
index ce2de7052..54569434a 100644
--- a/src/lerobot/policies/diffusion/configuration_diffusion.py
+++ b/src/lerobot/policies/diffusion/configuration_diffusion.py
@@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig):
)
# Check that all input images have the same shape.
- first_image_key, first_image_ft = next(iter(self.image_features.items()))
- for key, image_ft in self.image_features.items():
- if image_ft.shape != first_image_ft.shape:
- raise ValueError(
- f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
- )
+ if len(self.image_features) > 0:
+ first_image_key, first_image_ft = next(iter(self.image_features.items()))
+ for key, image_ft in self.image_features.items():
+ if image_ft.shape != first_image_ft.shape:
+ raise ValueError(
+ f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
+ )
@property
def observation_delta_indices(self) -> list:
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index af40f7a86..3ab6719cb 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -22,7 +22,7 @@ TODO(alexander-soare):
import math
from collections import deque
-from typing import Callable
+from collections.abc import Callable
import einops
import numpy as np
@@ -33,9 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
-from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
@@ -43,6 +41,7 @@ from lerobot.policies.utils import (
get_output_shape,
populate_queues,
)
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class DiffusionPolicy(PreTrainedPolicy):
@@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -91,28 +81,25 @@ class DiffusionPolicy(PreTrainedPolicy):
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
- "observation.state": deque(maxlen=self.config.n_obs_steps),
- "action": deque(maxlen=self.config.n_action_steps),
+ OBS_STATE: deque(maxlen=self.config.n_obs_steps),
+ ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
- self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
+ self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
- self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
+ self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
- @torch.no_grad
- def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ @torch.no_grad()
+ def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
- actions = self.diffusion.generate_actions(batch)
-
- # TODO(rcadene): make above methods return output dictionary?
- actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+ actions = self.diffusion.generate_actions(batch, noise=noise)
return actions
- @torch.no_grad
- def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ @torch.no_grad()
+ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
@@ -133,15 +120,18 @@ class DiffusionPolicy(PreTrainedPolicy):
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
- batch = self.normalize_inputs(batch)
+ # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
+ if ACTION in batch:
+ batch.pop(ACTION)
+
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
- # Note: It's important that this happens after stacking the images into a single key.
+ # NOTE: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
- actions = self.predict_action_chunk(batch)
+ actions = self.predict_action_chunk(batch, noise=noise)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
@@ -149,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
- batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
- batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
return loss, None
@@ -211,17 +199,25 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
- self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
+ self,
+ batch_size: int,
+ global_cond: Tensor | None = None,
+ generator: torch.Generator | None = None,
+ noise: Tensor | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
- sample = torch.randn(
- size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
- dtype=dtype,
- device=device,
- generator=generator,
+ sample = (
+ noise
+ if noise is not None
+ else torch.randn(
+ size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ )
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
@@ -246,7 +242,7 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
- images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
+ images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
@@ -261,7 +257,7 @@ class DiffusionModel(nn.Module):
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
- einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+ einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
@@ -276,7 +272,7 @@ class DiffusionModel(nn.Module):
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
- def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
+ def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""
This function expects `batch` to have:
{
@@ -284,17 +280,17 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
- "observation.environment_state": (B, environment_dim)
+ "observation.environment_state": (B, n_obs_steps, environment_dim)
}
"""
- batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+ batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling
- actions = self.conditional_sample(batch_size, global_cond=global_cond)
+ actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
@@ -311,17 +307,17 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
- "observation.environment_state": (B, environment_dim)
+ "observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
- assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
- assert "observation.images" in batch or "observation.environment_state" in batch
- n_obs_steps = batch["observation.state"].shape[1]
- horizon = batch["action"].shape[1]
+ assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
+ assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
+ n_obs_steps = batch[OBS_STATE].shape[1]
+ horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
@@ -329,7 +325,7 @@ class DiffusionModel(nn.Module):
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
- trajectory = batch["action"]
+ trajectory = batch[ACTION]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
@@ -350,7 +346,7 @@ class DiffusionModel(nn.Module):
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
- target = batch["action"]
+ target = batch[ACTION]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py
new file mode 100644
index 000000000..a7799be64
--- /dev/null
+++ b/src/lerobot/policies/diffusion/processor_diffusion.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_diffusion_pre_post_processors(
+ config: DiffusionConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for a diffusion policy.
+
+ The pre-processing pipeline prepares the input data for the model by:
+ 1. Renaming features.
+ 2. Normalizing the input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Moving the data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving the data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the diffusion policy,
+ containing feature definitions, normalization mappings, and device information.
+ dataset_stats: A dictionary of statistics used for normalization.
+ Defaults to None.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index ef56bdb61..ac76baf9f 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -14,9 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
+from __future__ import annotations
-from torch import nn
+import logging
+from typing import Any, TypedDict
+
+import torch
+from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
@@ -28,16 +32,40 @@ from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
+from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+from lerobot.processor.converters import (
+ batch_to_transition,
+ policy_action_to_transition,
+ transition_to_batch,
+ transition_to_policy_action,
+)
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
-def get_policy_class(name: str) -> PreTrainedPolicy:
- """Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
+def get_policy_class(name: str) -> type[PreTrainedPolicy]:
+ """
+ Retrieves a policy class by its registered name.
+
+ This function uses dynamic imports to avoid loading all policy classes into memory
+ at once, improving startup time and reducing dependencies.
+
+ Args:
+ name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
+ "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
+
+ Returns:
+ The policy class corresponding to the given name.
+
+ Raises:
+ NotImplementedError: If the policy name is not recognized.
+ """
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -54,14 +82,18 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
- elif name == "pi0":
- from lerobot.policies.pi0.modeling_pi0 import PI0Policy
-
- return PI0Policy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
+ elif name == "pi0":
+ from lerobot.policies.pi0.modeling_pi0 import PI0Policy
+
+ return PI0Policy
+ elif name == "pi05":
+ from lerobot.policies.pi05.modeling_pi05 import PI05Policy
+
+ return PI05Policy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -79,6 +111,24 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
+ """
+ Instantiates a policy configuration object based on the policy type.
+
+ This factory function simplifies the creation of policy configuration objects by
+ mapping a string identifier to the corresponding config class.
+
+ Args:
+ policy_type: The type of the policy. Supported types include "tdmpc",
+ "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
+ "reward_classifier".
+ **kwargs: Keyword arguments to be passed to the configuration class constructor.
+
+ Returns:
+ An instance of a `PreTrainedConfig` subclass.
+
+ Raises:
+ ValueError: If the `policy_type` is not recognized.
+ """
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
@@ -87,10 +137,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
- elif policy_type == "pi0":
- return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
+ elif policy_type == "pi0":
+ return PI0Config(**kwargs)
+ elif policy_type == "pi05":
+ return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -101,30 +153,195 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
raise ValueError(f"Policy type '{policy_type}' is not available.")
+class ProcessorConfigKwargs(TypedDict, total=False):
+ """
+ A TypedDict defining the keyword arguments for processor configuration.
+
+ This provides type hints for the optional arguments passed to `make_pre_post_processors`,
+ improving code clarity and enabling static analysis.
+
+ Attributes:
+ preprocessor_config_filename: The filename for the preprocessor configuration.
+ postprocessor_config_filename: The filename for the postprocessor configuration.
+ preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
+ postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
+ dataset_stats: Dataset statistics for normalization.
+ """
+
+ preprocessor_config_filename: str | None
+ postprocessor_config_filename: str | None
+ preprocessor_overrides: dict[str, Any] | None
+ postprocessor_overrides: dict[str, Any] | None
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None
+
+
+def make_pre_post_processors(
+ policy_cfg: PreTrainedConfig,
+ pretrained_path: str | None = None,
+ **kwargs: Unpack[ProcessorConfigKwargs],
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Create or load pre- and post-processor pipelines for a given policy.
+
+ This function acts as a factory. It can either load existing processor pipelines
+ from a pretrained path or create new ones from scratch based on the policy
+ configuration. Each policy type has a dedicated factory function for its
+ processors (e.g., `make_tdmpc_pre_post_processors`).
+
+ Args:
+ policy_cfg: The configuration of the policy for which to create processors.
+ pretrained_path: An optional path to load pretrained processor pipelines from.
+ If provided, pipelines are loaded from this path.
+ **kwargs: Keyword arguments for processor configuration, as defined in
+ `ProcessorConfigKwargs`.
+
+ Returns:
+ A tuple containing the input (pre-processor) and output (post-processor) pipelines.
+
+ Raises:
+ NotImplementedError: If a processor factory is not implemented for the given
+ policy configuration type.
+ """
+ if pretrained_path:
+ return (
+ PolicyProcessorPipeline.from_pretrained(
+ pretrained_model_name_or_path=pretrained_path,
+ config_filename=kwargs.get(
+ "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
+ ),
+ overrides=kwargs.get("preprocessor_overrides", {}),
+ to_transition=batch_to_transition,
+ to_output=transition_to_batch,
+ ),
+ PolicyProcessorPipeline.from_pretrained(
+ pretrained_model_name_or_path=pretrained_path,
+ config_filename=kwargs.get(
+ "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
+ ),
+ overrides=kwargs.get("postprocessor_overrides", {}),
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
+
+ # Create a new processor based on policy type
+ if isinstance(policy_cfg, TDMPCConfig):
+ from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
+
+ processors = make_tdmpc_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, DiffusionConfig):
+ from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
+
+ processors = make_diffusion_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, ACTConfig):
+ from lerobot.policies.act.processor_act import make_act_pre_post_processors
+
+ processors = make_act_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, VQBeTConfig):
+ from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
+
+ processors = make_vqbet_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, PI0FASTConfig):
+ from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
+
+ processors = make_pi0fast_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, PI0Config):
+ from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
+
+ processors = make_pi0_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, PI05Config):
+ from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
+
+ processors = make_pi05_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, SACConfig):
+ from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
+
+ processors = make_sac_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, RewardClassifierConfig):
+ from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
+
+ processors = make_classifier_processor(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ elif isinstance(policy_cfg, SmolVLAConfig):
+ from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
+
+ processors = make_smolvla_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
+
+ else:
+ raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
+
+ return processors
+
+
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
- """Make an instance of a policy class.
+ """
+ Instantiate a policy model.
- This function exists because (for now) we need to parse features from either a dataset or an environment
- in order to properly dimension and instantiate a policy for that dataset or environment.
+ This factory function handles the logic of creating a policy, which requires
+ determining the input and output feature shapes. These shapes can be derived
+ either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
+ can either initialize a new policy from scratch or load a pretrained one.
Args:
- cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
- be loaded with the weights from that path.
- ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
- statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
- env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
- provided if ds_meta is not. Defaults to None.
-
- Raises:
- ValueError: Either ds_meta or env and env_cfg must be provided.
- NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
+ cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
+ set, the policy will be loaded with weights from that path.
+ ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
+ statistics for normalization layers.
+ env_cfg: Environment configuration used to infer feature shapes and types.
+ One of `ds_meta` or `env_cfg` must be provided.
Returns:
- PreTrainedPolicy: _description_
+ An instantiated and device-placed policy model.
+
+ Raises:
+ ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
+ NotImplementedError: If attempting to use an unsupported policy-backend
+ combination (e.g., VQBeT with 'mps').
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
@@ -147,7 +364,6 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
- kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
@@ -155,6 +371,8 @@ def make_policy(
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
+ if env_cfg is None:
+ raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
@@ -171,7 +389,7 @@ def make_policy(
policy = policy_cls(**kwargs)
policy.to(cfg.device)
- assert isinstance(policy, nn.Module)
+ assert isinstance(policy, torch.nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py
deleted file mode 100644
index 9cc94b929..000000000
--- a/src/lerobot/policies/normalize.py
+++ /dev/null
@@ -1,420 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import numpy as np
-import torch
-from torch import Tensor, nn
-
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-
-
-def create_stats_buffers(
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
-) -> dict[str, dict[str, nn.ParameterDict]]:
- """
- Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
- statistics.
-
- Args: (see Normalize and Unnormalize)
-
- Returns:
- dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
- `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
- """
- stats_buffers = {}
-
- for key, ft in features.items():
- norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- assert isinstance(norm_mode, NormalizationMode)
-
- shape = tuple(ft.shape)
-
- if ft.type is FeatureType.VISUAL:
- # sanity checks
- assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
- c, h, w = shape
- assert c < h and c < w, f"{key} is not channel first ({shape=})"
- # override image shape to be invariant to height and width
- shape = (c, 1, 1)
-
- # Note: we initialize mean, std, min, max to infinity. They should be overwritten
- # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
- # we assert they are not infinity anymore.
-
- buffer = {}
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = torch.ones(shape, dtype=torch.float32) * torch.inf
- std = torch.ones(shape, dtype=torch.float32) * torch.inf
- buffer = nn.ParameterDict(
- {
- "mean": nn.Parameter(mean, requires_grad=False),
- "std": nn.Parameter(std, requires_grad=False),
- }
- )
- elif norm_mode is NormalizationMode.MIN_MAX:
- min = torch.ones(shape, dtype=torch.float32) * torch.inf
- max = torch.ones(shape, dtype=torch.float32) * torch.inf
- buffer = nn.ParameterDict(
- {
- "min": nn.Parameter(min, requires_grad=False),
- "max": nn.Parameter(max, requires_grad=False),
- }
- )
-
- # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
- if stats:
- if isinstance(stats[key]["mean"], np.ndarray):
- if norm_mode is NormalizationMode.MEAN_STD:
- buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
- buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
- elif norm_mode is NormalizationMode.MIN_MAX:
- buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
- buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
- elif isinstance(stats[key]["mean"], torch.Tensor):
- # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
- # tensors anywhere (for example, when we use the same stats for normalization and
- # unnormalization). See the logic here
- # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
- if norm_mode is NormalizationMode.MEAN_STD:
- buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
- buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
- elif norm_mode is NormalizationMode.MIN_MAX:
- buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
- buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
- else:
- type_ = type(stats[key]["mean"])
- raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
-
- stats_buffers[key] = buffer
- return stats_buffers
-
-
-def _no_stats_error_str(name: str) -> str:
- return (
- f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
- "pretrained model."
- )
-
-
-class Normalize(nn.Module):
- """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
-
- def __init__(
- self,
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
- ):
- """
- Args:
- shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
- are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
- mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
- is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
- modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
- are their normalization modes among:
- - "mean_std": subtract the mean and divide by standard deviation.
- - "min_max": map to [-1, 1] range.
- stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
- and values are dictionaries of statistic types and their values (e.g.
- `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
- training the model for the first time, these statistics will overwrite the default buffers. If
- not provided, as expected for finetuning or evaluation, the default buffers should to be
- overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
- dataset is not needed to get the stats, since they are already in the policy state_dict.
- """
- super().__init__()
- self.features = features
- self.norm_map = norm_map
- self.stats = stats
- stats_buffers = create_stats_buffers(features, norm_map, stats)
- for key, buffer in stats_buffers.items():
- setattr(self, "buffer_" + key.replace(".", "_"), buffer)
-
- # TODO(rcadene): should we remove torch.no_grad?
- @torch.no_grad
- def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- # TODO: Remove this shallow copy
- batch = dict(batch) # shallow copy avoids mutating the input batch
- for key, ft in self.features.items():
- if key not in batch:
- # FIXME(aliberts, rcadene): This might lead to silent fail!
- continue
-
- norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- buffer = getattr(self, "buffer_" + key.replace(".", "_"))
-
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = buffer["mean"]
- std = buffer["std"]
- assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
- assert not torch.isinf(std).any(), _no_stats_error_str("std")
- batch[key] = (batch[key] - mean) / (std + 1e-8)
- elif norm_mode is NormalizationMode.MIN_MAX:
- min = buffer["min"]
- max = buffer["max"]
- assert not torch.isinf(min).any(), _no_stats_error_str("min")
- assert not torch.isinf(max).any(), _no_stats_error_str("max")
- # normalize to [0,1]
- batch[key] = (batch[key] - min) / (max - min + 1e-8)
- # normalize to [-1, 1]
- batch[key] = batch[key] * 2 - 1
- else:
- raise ValueError(norm_mode)
- return batch
-
-
-class Unnormalize(nn.Module):
- """
- Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
- original range used by the environment.
- """
-
- def __init__(
- self,
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
- ):
- """
- Args:
- shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
- are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
- mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
- is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
- modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
- are their normalization modes among:
- - "mean_std": subtract the mean and divide by standard deviation.
- - "min_max": map to [-1, 1] range.
- stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
- and values are dictionaries of statistic types and their values (e.g.
- `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
- training the model for the first time, these statistics will overwrite the default buffers. If
- not provided, as expected for finetuning or evaluation, the default buffers should to be
- overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
- dataset is not needed to get the stats, since they are already in the policy state_dict.
- """
- super().__init__()
- self.features = features
- self.norm_map = norm_map
- self.stats = stats
- # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
- stats_buffers = create_stats_buffers(features, norm_map, stats)
- for key, buffer in stats_buffers.items():
- setattr(self, "buffer_" + key.replace(".", "_"), buffer)
-
- # TODO(rcadene): should we remove torch.no_grad?
- @torch.no_grad
- def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- batch = dict(batch) # shallow copy avoids mutating the input batch
- for key, ft in self.features.items():
- if key not in batch:
- continue
-
- norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- buffer = getattr(self, "buffer_" + key.replace(".", "_"))
-
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = buffer["mean"]
- std = buffer["std"]
- assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
- assert not torch.isinf(std).any(), _no_stats_error_str("std")
- batch[key] = batch[key] * std + mean
- elif norm_mode is NormalizationMode.MIN_MAX:
- min = buffer["min"]
- max = buffer["max"]
- assert not torch.isinf(min).any(), _no_stats_error_str("min")
- assert not torch.isinf(max).any(), _no_stats_error_str("max")
- batch[key] = (batch[key] + 1) / 2
- batch[key] = batch[key] * (max - min) + min
- else:
- raise ValueError(norm_mode)
- return batch
-
-
-# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
-# and remove the `Normalize` and `Unnormalize` classes.
-def _initialize_stats_buffers(
- module: nn.Module,
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
-) -> None:
- """Register statistics buffers (mean/std or min/max) on the given *module*.
-
- The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
- but is factored out so it can be reused by both classes and stay in sync.
- """
- for key, ft in features.items():
- norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- shape: tuple[int, ...] = tuple(ft.shape)
- if ft.type is FeatureType.VISUAL:
- # reduce spatial dimensions, keep channel dimension only
- c, *_ = shape
- shape = (c, 1, 1)
-
- prefix = key.replace(".", "_")
-
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = torch.full(shape, torch.inf, dtype=torch.float32)
- std = torch.full(shape, torch.inf, dtype=torch.float32)
-
- if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
- mean_data = stats[key]["mean"]
- std_data = stats[key]["std"]
- if isinstance(mean_data, torch.Tensor):
- # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
- # tensors anywhere (for example, when we use the same stats for normalization and
- # unnormalization). See the logic here
- # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
- mean = mean_data.clone().to(dtype=torch.float32)
- std = std_data.clone().to(dtype=torch.float32)
- else:
- raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
-
- module.register_buffer(f"{prefix}_mean", mean)
- module.register_buffer(f"{prefix}_std", std)
- continue
-
- if norm_mode is NormalizationMode.MIN_MAX:
- min_val = torch.full(shape, torch.inf, dtype=torch.float32)
- max_val = torch.full(shape, torch.inf, dtype=torch.float32)
-
- if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
- min_data = stats[key]["min"]
- max_data = stats[key]["max"]
- if isinstance(min_data, torch.Tensor):
- min_val = min_data.clone().to(dtype=torch.float32)
- max_val = max_data.clone().to(dtype=torch.float32)
- else:
- raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
-
- module.register_buffer(f"{prefix}_min", min_val)
- module.register_buffer(f"{prefix}_max", max_val)
- continue
-
- raise ValueError(norm_mode)
-
-
-class NormalizeBuffer(nn.Module):
- """Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
-
- def __init__(
- self,
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
- ):
- super().__init__()
- self.features = features
- self.norm_map = norm_map
-
- _initialize_stats_buffers(self, features, norm_map, stats)
-
- def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- batch = dict(batch)
- for key, ft in self.features.items():
- if key not in batch:
- continue
-
- norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- prefix = key.replace(".", "_")
-
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = getattr(self, f"{prefix}_mean")
- std = getattr(self, f"{prefix}_std")
- assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
- assert not torch.isinf(std).any(), _no_stats_error_str("std")
- batch[key] = (batch[key] - mean) / (std + 1e-8)
- continue
-
- if norm_mode is NormalizationMode.MIN_MAX:
- min_val = getattr(self, f"{prefix}_min")
- max_val = getattr(self, f"{prefix}_max")
- assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
- assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
- batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
- batch[key] = batch[key] * 2 - 1
- continue
-
- raise ValueError(norm_mode)
-
- return batch
-
-
-class UnnormalizeBuffer(nn.Module):
- """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
-
- def __init__(
- self,
- features: dict[str, PolicyFeature],
- norm_map: dict[str, NormalizationMode],
- stats: dict[str, dict[str, Tensor]] | None = None,
- ):
- super().__init__()
- self.features = features
- self.norm_map = norm_map
-
- _initialize_stats_buffers(self, features, norm_map, stats)
-
- def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
- # batch = dict(batch)
- for key, ft in self.features.items():
- if key not in batch:
- continue
-
- norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
- if norm_mode is NormalizationMode.IDENTITY:
- continue
-
- prefix = key.replace(".", "_")
-
- if norm_mode is NormalizationMode.MEAN_STD:
- mean = getattr(self, f"{prefix}_mean")
- std = getattr(self, f"{prefix}_std")
- assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
- assert not torch.isinf(std).any(), _no_stats_error_str("std")
- batch[key] = batch[key] * std + mean
- continue
-
- if norm_mode is NormalizationMode.MIN_MAX:
- min_val = getattr(self, f"{prefix}_min")
- max_val = getattr(self, f"{prefix}_max")
- assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
- assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
- batch[key] = (batch[key] + 1) / 2
- batch[key] = batch[key] * (max_val - min_val) + min_val
- continue
-
- raise ValueError(norm_mode)
-
- return batch
diff --git a/src/lerobot/policies/pi0/README.md b/src/lerobot/policies/pi0/README.md
new file mode 100644
index 000000000..65b331e51
--- /dev/null
+++ b/src/lerobot/policies/pi0/README.md
@@ -0,0 +1,49 @@
+# π₀ (pi0)
+
+This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
+It is designed as a **Vision-Language-Action model for general robot control**.
+
+---
+
+## Model Overview
+
+| Feature | π₀ | π₀.₅ |
+| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
+| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
+| AdaRMS | Not used | Used in action expert |
+| Tokenizer Length | 48 tokens | 200 tokens |
+| Discrete State Input | False (Uses `state_proj` layer) | True |
+| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
+
+---
+
+## Citation
+
+If you use this work, please cite both **OpenPI** and the π₀ paper:
+
+```bibtex
+@misc{openpi2024,
+ author = {Physical Intelligence Lab},
+ title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
+ year = {2024},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
+ license = {Apache-2.0}
+}
+
+@misc{black2024pi0visionlanguageactionflowmodel,
+ title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
+ author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
+ year = {2024},
+ eprint = {2410.24164},
+ archivePrefix= {arXiv},
+ primaryClass = {cs.LG},
+ url = {https://arxiv.org/abs/2410.24164},
+}
+```
+
+---
+
+## License
+
+This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
diff --git a/src/lerobot/policies/pi0/__init__.py b/src/lerobot/policies/pi0/__init__.py
new file mode 100644
index 000000000..ea3095b4e
--- /dev/null
+++ b/src/lerobot/policies/pi0/__init__.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_pi0 import PI0Config
+from .modeling_pi0 import PI0Policy
+from .processor_pi0 import make_pi0_pre_post_processors
+
+__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py
index c9728e418..cc1cda9d8 100644
--- a/src/lerobot/policies/pi0/configuration_pi0.py
+++ b/src/lerobot/policies/pi0/configuration_pi0.py
@@ -1,4 +1,6 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,19 +19,40 @@ from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import (
- CosineDecayWithWarmupSchedulerConfig,
-)
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
- # Input / output structure.
- n_obs_steps: int = 1
- chunk_size: int = 50
- n_action_steps: int = 50
+ paligemma_variant: str = "gemma_2b"
+ action_expert_variant: str = "gemma_300m"
+ dtype: str = "float32" # Options: "bfloat16", "float32"
+ n_obs_steps: int = 1
+ chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
+ n_action_steps: int = 50 # Number of action steps to execute
+
+ # Shorter state and action vectors will be padded to these dimensions
+ max_state_dim: int = 32
+ max_action_dim: int = 32
+
+ # Flow matching parameters: see openpi `PI0Pytorch`
+ num_inference_steps: int = 10 # Number of denoising steps during inference
+ time_sampling_beta_alpha: float = 1.5
+ time_sampling_beta_beta: float = 1.0
+ time_sampling_scale: float = 0.999
+ time_sampling_offset: float = 0.001
+ min_period: float = 4e-3
+ max_period: float = 4.0
+
+ image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
+
+ # Add empty images. Used to add empty cameras when no image features are present.
+ empty_cameras: int = 0
+
+ # Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -38,94 +61,75 @@ class PI0Config(PreTrainedConfig):
}
)
- # Shorter state and action vectors will be padded
- max_state_dim: int = 32
- max_action_dim: int = 32
+ # Training settings
+ gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
+ compile_model: bool = False # Whether to use torch.compile for model optimization
+ compile_mode: str = "max-autotune" # Torch compile mode
+ device: str | None = None # Device to use for the model (None = auto-detect)
- # Image preprocessing
- resize_imgs_with_padding: tuple[int, int] = (224, 224)
-
- # Add empty images. Used by pi0_aloha_sim which adds the empty
- # left and right wrist cameras in addition to the top camera.
- empty_cameras: int = 0
-
- # Converts the joint and gripper values from the standard Aloha space to
- # the space used by the pi internal runtime which was used to train the base model.
- adapt_to_pi_aloha: bool = False
-
- # Converts joint dimensions to deltas with respect to the current state before passing to the model.
- # Gripper dimensions will remain in absolute values.
- use_delta_joint_actions_aloha: bool = False
-
- # Tokenizer
- tokenizer_max_length: int = 48
-
- # Projector
- proj_width: int = 1024
-
- # Decoding
- num_steps: int = 10
-
- # Attention utils
- use_cache: bool = True
- attention_implementation: str = "eager" # or fa2, flex
-
- # Finetuning settings
- freeze_vision_encoder: bool = True
- train_expert_only: bool = False
- train_state_proj: bool = True
-
- # Training presets
- optimizer_lr: float = 2.5e-5
+ # Optimizer settings: see openpi `AdamW``
+ optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
- optimizer_weight_decay: float = 1e-10
+ optimizer_weight_decay: float = 0.01
+ optimizer_grad_clip_norm: float = 1.0
+ # Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
- # TODO: Add EMA
+ tokenizer_max_length: int = 48 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
- # TODO(Steven): Validate device and amp? in all policy configs?
- """Input validation (not exhaustive)."""
+ # Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
- f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
- f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
- )
- if self.n_obs_steps != 1:
- raise ValueError(
- f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
+ f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
- if self.use_delta_joint_actions_aloha:
- raise NotImplementedError(
- "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
- )
+ if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
+ raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
+
+ if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
+ raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
+
+ if self.dtype not in ["bfloat16", "float32"]:
+ raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
- # TODO: implement value error
- # if not self.image_features and not self.env_state_feature:
- # raise ValueError("You must provide at least one image or the environment state among the inputs.")
-
+ """Validate and set up input/output features."""
for i in range(self.empty_cameras):
- key = f"observation.images.empty_camera_{i}"
+ key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
- shape=(3, 480, 640),
+ shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
+ if "observation.state" not in self.input_features:
+ state_feature = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(self.max_state_dim,), # Padded to max_state_dim
+ )
+ self.input_features["observation.state"] = state_feature
+
+ if "action" not in self.output_features:
+ action_feature = PolicyFeature(
+ type=FeatureType.ACTION,
+ shape=(self.max_action_dim,), # Padded to max_action_dim
+ )
+ self.output_features["action"] = action_feature
+
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
+ grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
diff --git a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py b/src/lerobot/policies/pi0/conversion_scripts/benchmark.py
deleted file mode 100644
index c1a488244..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy
-
-torch.backends.cudnn.benchmark = True
-
-
-def main():
- device = "cuda"
- dataset_repo_id = "danaaubakirova/koch_test"
- # model_name = "pi0_base"
- # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
- ckpt_torch_dir = "lerobot/pi0"
-
- dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- num_workers=0,
- batch_size=1,
- )
-
- batch = next(iter(dataloader))
-
- # To device
- for k in batch:
- if isinstance(batch[k], torch.Tensor):
- batch[k] = batch[k].to(device=device, dtype=torch.float32)
-
- cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
- cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, ds_meta=dataset.meta)
-
- # policy = torch.compile(policy, mode="reduce-overhead")
-
- warmup_iters = 10
- benchmark_iters = 30
-
- # Warmup
- for _ in range(warmup_iters):
- torch.cuda.synchronize()
- policy.select_action(batch)
- policy.reset()
- torch.cuda.synchronize()
-
- # Benchmark
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
-
- start_event.record()
- for _ in range(benchmark_iters):
- policy.select_action(batch)
- policy.reset()
- end_event.record()
-
- # Synchronize and measure time
- torch.cuda.synchronize()
- elapsed_time_ms = start_event.elapsed_time(end_event)
-
- avg_time_per_iter = elapsed_time_ms / benchmark_iters
- print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
-
-
-if __name__ == "__main__":
- with torch.inference_mode():
- main()
diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py
deleted file mode 100644
index c0c2e4816..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import pickle
-from pathlib import Path
-
-import torch
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.policies.factory import make_policy
-
-
-def display(tensor: torch.Tensor):
- if tensor.dtype == torch.bool:
- tensor = tensor.float()
- print(f"Shape: {tensor.shape}")
- print(f"Mean: {tensor.mean().item()}")
- print(f"Std: {tensor.std().item()}")
- print(f"Min: {tensor.min().item()}")
- print(f"Max: {tensor.max().item()}")
-
-
-def main():
- num_motors = 14
- device = "cuda"
- # model_name = "pi0_aloha_towel"
- model_name = "pi0_aloha_sim"
-
- if model_name == "pi0_aloha_towel":
- dataset_repo_id = "lerobot/aloha_static_towel"
- else:
- dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
-
- ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
- ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
- save_dir = Path(f"../openpi/data/{model_name}/save")
-
- with open(save_dir / "example.pkl", "rb") as f:
- example = pickle.load(f)
- with open(save_dir / "outputs.pkl", "rb") as f:
- outputs = pickle.load(f)
- with open(save_dir / "noise.pkl", "rb") as f:
- noise = pickle.load(f)
-
- with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
- norm_stats = json.load(f)
-
- # Override stats
- dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
- dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
- norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
- )
- dataset_meta.stats["observation.state"]["std"] = torch.tensor(
- norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
- )
-
- # Create LeRobot batch from Jax
- batch = {}
- for cam_key, uint_chw_array in example["images"].items():
- batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
- batch["observation.state"] = torch.from_numpy(example["state"])
- batch["action"] = torch.from_numpy(outputs["actions"])
- batch["task"] = example["prompt"]
-
- if model_name == "pi0_aloha_towel":
- del batch["observation.images.cam_low"]
- elif model_name == "pi0_aloha_sim":
- batch["observation.images.top"] = batch["observation.images.cam_high"]
- del batch["observation.images.cam_high"]
-
- # Batchify
- for key in batch:
- if isinstance(batch[key], torch.Tensor):
- batch[key] = batch[key].unsqueeze(0)
- elif isinstance(batch[key], str):
- batch[key] = [batch[key]]
- else:
- raise ValueError(f"{key}, {batch[key]}")
-
- # To device
- for k in batch:
- if isinstance(batch[k], torch.Tensor):
- batch[k] = batch[k].to(device=device, dtype=torch.float32)
-
- noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
-
- from lerobot import policies # noqa
-
- cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
- cfg.pretrained_path = ckpt_torch_dir
- policy = make_policy(cfg, dataset_meta)
-
- # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
- # loss_dict["loss"].backward()
- # print("losses")
- # display(loss_dict["losses_after_forward"])
- # print("pi_losses")
- # display(pi_losses)
-
- actions = []
- for _ in range(50):
- action = policy.select_action(batch, noise=noise)
- actions.append(action)
-
- actions = torch.stack(actions, dim=1)
- pi_actions = batch["action"]
- print("actions")
- display(actions)
- print()
- print("pi_actions")
- display(pi_actions)
- print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
- print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
- print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py b/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py
deleted file mode 100644
index 8835da31e..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import GemmaConfig, PaliGemmaConfig
-
-
-def get_paligemma_config(precision: str):
- config = {
- "image_token_index": None,
- "pad_token_id": 0,
- "bos_token_id": 2,
- "eos_token_id": 1,
- }
-
- # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
-
- image_size = 224 # image_sizes[variant]
- patch_size = 14
- num_image_tokens = (image_size**2) // (patch_size**2)
-
- config["image_token_index"] = 257152
- text_config = {
- "vocab_size": 257152,
- "num_hidden_layers": 18,
- "num_key_value_heads": 1,
- "head_dim": 256,
- "torch_dtype": precision,
- "hidden_size": 2048,
- "hidden_activation": "gelu_pytorch_tanh",
- "num_attention_heads": 8,
- "intermediate_size": 16384,
- "is_encoder_decoder": False,
- }
- vision_config = {
- "torch_dtype": precision,
- "image_size": image_size,
- "patch_size": patch_size,
- "num_image_tokens": num_image_tokens,
- "hidden_size": 1152,
- "intermediate_size": 4304,
- "num_hidden_layers": 27,
- "num_attention_heads": 16,
- "projector_hidden_act": "gelu_fast",
- "vision_use_head": False,
- }
- final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
- return final_config
-
-
-def get_gemma_config(precision: str):
- config = {
- "image_token_index": None,
- "pad_token_id": 0,
- "bos_token_id": 2,
- "eos_token_id": 1,
- }
-
- config["image_token_index"] = 257152
- text_config = {
- "vocab_size": 257152,
- "num_hidden_layers": 18,
- "num_key_value_heads": 1,
- "head_dim": 256,
- "torch_dtype": precision,
- "hidden_size": 1024,
- "hidden_activation": "gelu_pytorch_tanh",
- "num_attention_heads": 8,
- "intermediate_size": 4096,
- "is_encoder_decoder": False,
- }
- final_config = GemmaConfig()
- final_config.update(text_config)
- return final_config
diff --git a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
deleted file mode 100644
index 742c9ab3f..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ /dev/null
@@ -1,437 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Convert pi0 parameters from Jax to Pytorch
-
-Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
-and install the required libraries.
-
-```bash
-cd ~/code/openpi
-source .venv/bin/activate
-```
-
-Example downloading parameters:
-```bash
-python
->>> import openpi.shared.download as download
->>> path='s3://openpi-assets/checkpoints/pi0_base/params'
->>> download.maybe_download(path)
-```
-
-Converting pi0_base:
-```python
-python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
- --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
- --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
-```
-
-```python
-python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
- --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
- --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
-```
-"""
-
-import argparse
-import pathlib
-
-import jax
-import numpy as np
-import orbax.checkpoint as ocp
-import torch
-from jax.sharding import SingleDeviceSharding
-
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
- get_gemma_config,
- get_paligemma_config,
-)
-from lerobot.policies.pi0.modeling_pi0 import PI0Policy
-
-PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
-
-
-def slice_paligemma_state_dict(state_dict, config):
- suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
-
- # fmt: off
- # patch embeddings
- state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
- 3, 2, 0, 1
- )
- state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
- # positional embeddings
- state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
- -1, config.vision_config.hidden_size
- )
-
- # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
- encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
- encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
- encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
- encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
-
- encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
- encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
- encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
- encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
-
- encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
- encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
- encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
- encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
- encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
- encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
- encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
- encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
-
- for i in range(config.vision_config.num_hidden_layers):
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
-
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
- state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
-
- state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
- state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
-
- # multimodal projector
-
- state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
- state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
-
- # text decoder (gemma)
- embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
- state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
-
- # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
-
- llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
- llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
- llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
-
- llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
- llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
- # TODO verify correctness of layer norm loading
-
- llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
- llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
-
- for i in range(config.text_config.num_hidden_layers):
- # llm_attention_q_einsum[i].shape = (8, 2048, 256)
- q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
-
- state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
-
- # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
- k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
- state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
- # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
- v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
- state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
-
- # output projection.
-
- # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
- o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
-
- state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
- # mlp layers
- gate_proj_weight = llm_mlp_gating_einsum[i, 0]
- state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
- up_proj_weight = llm_mlp_gating_einsum[i, 1]
- state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
- state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
- state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
- state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
-
- state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
- state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
-
- # fmt: on
- expert_dict = {}
- final_state_dict = {}
- for key, value in state_dict.items():
- if key not in [
- f"llm/final_norm_1/scale{suffix}",
- f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
- f"llm/layers/attn/kv_einsum_1/w{suffix}",
- f"llm/layers/attn/q_einsum_1/w{suffix}",
- f"llm/layers/mlp_1/gating_einsum{suffix}",
- f"llm/layers/mlp_1/linear{suffix}",
- f"llm/layers/pre_attention_norm_1/scale{suffix}",
- f"llm/layers/pre_ffw_norm_1/scale{suffix}",
- ]:
- final_state_dict[key] = torch.from_numpy(value)
- else:
- expert_dict[key] = value
-
- return final_state_dict, expert_dict
-
-
-def slice_gemma_state_dict(state_dict, config, num_expert=1):
- # fmt: off
- # text decoder (gemma)
- # no embedding vector, the expert just has the decoder layers
-
- embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
- state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
-
- # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
-
- suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
-
- llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
- llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
- llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
-
- llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
- llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
- # TODO verify correctness of layer norm loading
-
- llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
- llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
-
- for i in range(config.num_hidden_layers):
- q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
-
- state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
-
- k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
- state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
- v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
- state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
-
- # output projection.
-
- # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
- o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
-
- state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
- # mlp layers
- gate_proj_weight = llm_mlp_gating_einsum[i, 0]
- state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
- up_proj_weight = llm_mlp_gating_einsum[i, 1]
- state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
- state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
- state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
- state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
-
- state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
- state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
-
- # fmt: on
- final_state_dict = {}
- for key, value in state_dict.items():
- if not isinstance(value, torch.Tensor):
- final_state_dict[key] = torch.from_numpy(value)
- else:
- final_state_dict[key] = value
- return final_state_dict
-
-
-def flatten_for_memory(tree, parent_key=""):
- out = {}
- for k, v in tree.items():
- new_key = f"{parent_key}/{k}" if parent_key else k
- if isinstance(v, dict):
- out.update(flatten_for_memory(v, new_key))
- else:
- out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
- return out
-
-
-def flatten_for_npz(tree, parent_key=""):
- out = {}
- for k, v in tree.items():
- new_key = f"{parent_key}/{k}" if parent_key else k
- if isinstance(v, dict):
- out.update(flatten_for_npz(v, new_key))
- else:
- # bf16/f32 here?
- out[new_key] = np.array(v)
- return out
-
-
-def slice_initial_orbax_checkpoint(checkpoint_dir: str):
- params_path = pathlib.Path(checkpoint_dir).resolve()
- checkpointer = ocp.PyTreeCheckpointer()
-
- metadata = checkpointer.metadata(params_path)
- print("Metadata keys:", list(metadata.keys()))
-
- params_name = "params"
-
- item = {params_name: metadata[params_name]}
- device = jax.local_devices()[0] # Use the first local device
- sharding = SingleDeviceSharding(device)
- restored = checkpointer.restore(
- params_path,
- ocp.args.PyTreeRestore(
- item=item,
- restore_args=jax.tree_util.tree_map(
- lambda _: ocp.ArrayRestoreArgs(
- restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
- sharding=sharding,
- ),
- item,
- ),
- transforms={},
- ),
- )
- params = restored[params_name]
-
- # get params for PaliGemma
- pali_params = params["PaliGemma"]
- del params["PaliGemma"]
- pali_params_flat = flatten_for_npz(pali_params)
- return {"paligemma_params": pali_params_flat, "projection_params": params}
-
-
-def update_keys_with_prefix(d: dict, prefix: str) -> dict:
- """Update dictionary keys by adding a prefix."""
- return {f"{prefix}{key}": value for key, value in d.items()}
-
-
-def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
- # Break down orbax ckpts - they are in OCDBT
- initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
- # process projection params
- keys = [
- "state_proj",
- "action_in_proj",
- "action_out_proj",
- "action_time_mlp_in",
- "action_time_mlp_out",
- ]
-
- projection_params = {}
- for key in keys:
- kernel_params = initial_params["projection_params"][key]["kernel"]
- bias_params = initial_params["projection_params"][key]["bias"]
- if isinstance(kernel_params, dict):
- weight = kernel_params["value"]
- bias = bias_params["value"]
- else:
- weight = kernel_params
- bias = bias_params
- projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
- projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
-
- # Process PaliGemma weights
- paligemma_config = get_paligemma_config(precision)
- paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
- initial_params["paligemma_params"], paligemma_config
- )
-
- # Process Gemma weights (at this stage they are unused)
- gemma_config = get_gemma_config(precision)
- gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
-
- # Instantiate model from configs
-
- if "pi0_aloha_sim" in checkpoint_dir:
- pi0_config = PI0Config(
- empty_cameras=2,
- adapt_to_pi_aloha=True,
- use_delta_joint_actions_aloha=False,
- )
- elif "pi0_aloha_towel" in checkpoint_dir:
- pi0_config = PI0Config(
- adapt_to_pi_aloha=True,
- use_delta_joint_actions_aloha=True,
- )
- elif "pi0_base" in checkpoint_dir:
- pi0_config = PI0Config(
- empty_cameras=0,
- adapt_to_pi_aloha=False,
- use_delta_joint_actions_aloha=False,
- )
- else:
- raise ValueError()
-
- # gemma_config=gemma_config, paligemma_config=paligemma_config)
- pi0_model = PI0Policy(pi0_config)
-
- paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
- gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
- projection_params = update_keys_with_prefix(projection_params, "model.")
-
- # load state dict
- torch_dtype = PRECISIONS[precision]
- pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
- pi0_model = pi0_model.to(torch_dtype)
- # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
-
- pi0_model.save_pretrained(output_path, safe_serialization=True)
- # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
-
- # assert that model loads properly
- del pi0_model
- PI0Policy.from_pretrained(output_path)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--checkpoint_dir",
- default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
- type=str,
- help="Path to the ocdbt checkpoint",
- )
-
- parser.add_argument(
- "--precision",
- choices=["float32", "bfloat16", "float16"],
- default="float32",
- type=str,
- help="Precision identifier for model conversion - should match the base checkpoint precision.",
- )
- # tokenizer is identical to paligemma, it appears
-
- parser.add_argument(
- "--tokenizer_hub_id",
- default="google/paligemma-3b-pt-224",
- type=str,
- help="Hub path to the tokenizer to save",
- )
-
- parser.add_argument(
- "--output_path",
- required=True,
- type=str,
- help="Path to save converted weights to",
- )
-
- args = parser.parse_args()
- convert_pi0_checkpoint(
- checkpoint_dir=args.checkpoint_dir,
- precision=args.precision,
- tokenizer_id=args.tokenizer_hub_id,
- output_path=args.output_path,
- )
diff --git a/src/lerobot/policies/pi0/flex_attention.py b/src/lerobot/policies/pi0/flex_attention.py
deleted file mode 100644
index 35628cddb..000000000
--- a/src/lerobot/policies/pi0/flex_attention.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn.functional as F # noqa: N812
-from packaging.version import Version
-
-if Version(torch.__version__) > Version("2.5.0"):
- # Ffex attention is only available from torch 2.5 onwards
- from torch.nn.attention.flex_attention import (
- _mask_mod_signature,
- _round_up_to_multiple,
- create_block_mask,
- create_mask,
- flex_attention,
- )
-
-
-# @torch.compile(dynamic=False)
-def flex_attention_forward(
- attention_mask: torch.Tensor,
- batch_size: int,
- head_dim: int,
- query_states: torch.Tensor,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- scaling=None,
-):
- """
- This is defined out of classes to make compile happy.
- """
-
- original_dtype = query_states.dtype
- num_att_heads = 8
- num_key_value_heads = 1
- num_key_value_groups = num_att_heads // num_key_value_heads
-
- key_states = key_states[:, :, :, None, :]
- key_states = key_states.expand(
- batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
- )
- key_states = key_states.reshape(
- batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
- )
-
- value_states = value_states[:, :, :, None, :]
- value_states = value_states.expand(
- batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
- )
- value_states = value_states.reshape(
- batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
- )
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- query_states = query_states.to(torch.float32)
- key_states = key_states.to(torch.float32)
- value_states = value_states.to(torch.float32)
-
- causal_mask = attention_mask
- if causal_mask is not None:
- causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
-
- if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
- causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
-
- def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
- def mask_mod(b, h, q_idx, kv_idx):
- # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
- return precomputed_mask[b][h][q_idx][kv_idx]
-
- return mask_mod
-
- b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
-
- block_size = 128
- q_len_rounded = _round_up_to_multiple(q_len, block_size)
- kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
-
- # *CRITICAL* we do need to expand here, else we get a CUDA index error
-
- pad_q = q_len_rounded - q_len
- pad_k = kv_len_rounded - kv_len
-
- padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
- mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
-
- mask_4d = create_mask(
- mod_fn=mask_mod_fn_orig,
- B=b_mask,
- H=h_mask,
- Q_LEN=q_len_rounded,
- KV_LEN=kv_len_rounded,
- device=causal_mask.device,
- _compile=False,
- )
-
- mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
- block_mask = create_block_mask(
- mask_mod=mask_mod_fn_padded,
- B=b_mask,
- H=h_mask,
- Q_LEN=q_len_rounded,
- KV_LEN=kv_len_rounded,
- BLOCK_SIZE=block_size,
- device=causal_mask.device,
- _compile=False,
- )
-
- # mask is applied inside the kernel, ideally more efficiently than score_mod.
- attn_output, attention_weights = flex_attention(
- query_states,
- key_states,
- value_states,
- block_mask=block_mask,
- enable_gqa=True, # because we shaped query/key states for GQA
- scale=head_dim**-0.5 if scaling is None else scaling,
- return_lse=True,
- )
-
- attn_output = attn_output.to(dtype=original_dtype)
- attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
- attn_output = attn_output.reshape(
- batch_size,
- -1,
- attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
- )
- return attn_output
diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index 241509d0b..a2dcdaea3 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -14,62 +14,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-π0: A Vision-Language-Action Flow Model for General Robot Control
-
-[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
-[Jax code](https://github.com/Physical-Intelligence/openpi)
-
-Designed by Physical Intelligence. Ported from Jax by Hugging Face.
-
-Install pi0 extra dependencies:
-```bash
-pip install -e ".[pi0]"
-```
-
-Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
-```bash
-python -m lerobot.scripts.train \
---policy.path=lerobot/pi0 \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
-pretrained with VLM default parameters before pi0 finetuning:
-```bash
-python -m lerobot.scripts.train \
---policy.type=pi0 \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of using the pi0 pretrained model outside LeRobot training framework:
-```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
-```
-
-"""
-
+import builtins
+import logging
import math
from collections import deque
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
-from transformers import AutoTokenizer
-from lerobot.constants import ACTION, OBS_STATE
-from lerobot.policies.normalize import Normalize, Unnormalize
+from lerobot.utils.import_utils import _transformers_available
+
+# Conditional import for type checking and lazy loading
+if TYPE_CHECKING or _transformers_available:
+ from transformers.models.auto import CONFIG_MAPPING
+ from transformers.models.gemma import modeling_gemma
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+else:
+ CONFIG_MAPPING = None
+ modeling_gemma = None
+ GemmaForCausalLM = None
+ PaliGemmaForConditionalGeneration = None
+
+from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0.paligemma_with_expert import (
- PaliGemmaWithExpertConfig,
- PaliGemmaWithExpertModel,
+from lerobot.policies.pretrained import PreTrainedPolicy, T
+from lerobot.utils.constants import (
+ ACTION,
+ OBS_LANGUAGE_ATTENTION_MASK,
+ OBS_LANGUAGE_TOKENS,
+ OBS_STATE,
+ OPENPI_ATTENTION_MASK_VALUE,
)
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.utils import get_safe_dtype
-def create_sinusoidal_pos_embedding(
- time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+def get_safe_dtype(target_dtype, device_type):
+ """Get a safe dtype for the given device type."""
+ if device_type == "mps" and target_dtype == torch.float64:
+ return torch.float32
+ if device_type == "cpu":
+ # CPU doesn't support bfloat16, use float32 instead
+ if target_dtype == torch.bfloat16:
+ return torch.float32
+ if target_dtype == torch.float64:
+ return torch.float64
+ return target_dtype
+
+
+def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
+ time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
@@ -85,17 +81,17 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
- pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
- return pos_emb
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
-def sample_beta(alpha, beta, bsize, device):
- gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
- gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
- return gamma1 / (gamma1 + gamma2)
+def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
+ alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+ beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+ dist = torch.distributions.Beta(alpha_t, beta_t)
+ return dist.sample((bsize,))
-def make_att_2d_masks(pad_masks, att_masks):
+def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
@@ -124,438 +120,514 @@ def make_att_2d_masks(pad_masks, att_masks):
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
- att_2d_masks = att_2d_masks & pad_2d_masks
- return att_2d_masks
-
-
-def resize_with_pad(img, width, height, pad_value=-1):
- # assume no-op when width height fits already
- if img.ndim != 4:
- raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
-
- cur_height, cur_width = img.shape[2:]
-
- ratio = max(cur_width / width, cur_height / height)
- resized_height = int(cur_height / ratio)
- resized_width = int(cur_width / ratio)
- resized_img = F.interpolate(
- img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
- )
-
- pad_height = max(0, int(height - resized_height))
- pad_width = max(0, int(width - resized_width))
-
- # pad on left and top of image
- padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
- return padded_img
+ return att_2d_masks & pad_2d_masks
def pad_vector(vector, new_dim):
- """Can be (batch_size x sequence_length x features_dimension)
+ """Pad the last dimension of a vector to new_dim with zeros.
+
+ Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
- if vector.shape[-1] == new_dim:
+ if vector.shape[-1] >= new_dim:
return vector
- shape = list(vector.shape)
- current_dim = shape[-1]
- shape[-1] = new_dim
- new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
- new_vector[..., :current_dim] = vector
- return new_vector
+ return F.pad(vector, (0, new_dim - vector.shape[-1]))
-def normalize(x, min_val, max_val):
- return (x - min_val) / (max_val - min_val)
+def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
+ images: torch.Tensor,
+ height: int,
+ width: int,
+ mode: str = "bilinear",
+) -> torch.Tensor:
+ """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
+ by padding with black. If the image is float32, it must be in the range [-1, 1].
+
+ Args:
+ images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
+ height: Target height
+ width: Target width
+ mode: Interpolation mode ('bilinear', 'nearest', etc.)
+
+ Returns:
+ Resized and padded tensor with same shape format as input
+ """
+ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
+ if images.shape[-1] <= 4: # Assume channels-last format
+ channels_last = True
+ if images.dim() == 3:
+ images = images.unsqueeze(0) # Add batch dimension
+ images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
+ else:
+ channels_last = False
+ if images.dim() == 3:
+ images = images.unsqueeze(0) # Add batch dimension
+
+ batch_size, channels, cur_height, cur_width = images.shape
+
+ # Calculate resize ratio
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+
+ # Resize
+ resized_images = F.interpolate(
+ images,
+ size=(resized_height, resized_width),
+ mode=mode,
+ align_corners=False if mode == "bilinear" else None,
+ )
+
+ # Handle dtype-specific clipping
+ if images.dtype == torch.uint8:
+ resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
+ elif images.dtype == torch.float32:
+ resized_images = resized_images.clamp(-1.0, 1.0)
+ else:
+ raise ValueError(f"Unsupported image dtype: {images.dtype}")
+
+ # Calculate padding
+ pad_h0, remainder_h = divmod(height - resized_height, 2)
+ pad_h1 = pad_h0 + remainder_h
+ pad_w0, remainder_w = divmod(width - resized_width, 2)
+ pad_w1 = pad_w0 + remainder_w
+
+ # Pad
+ constant_value = 0 if images.dtype == torch.uint8 else -1.0
+ padded_images = F.pad(
+ resized_images,
+ (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
+ mode="constant",
+ value=constant_value,
+ )
+
+ # Convert back to original format if needed
+ if channels_last:
+ padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
+
+ return padded_images
-def unnormalize(x, min_val, max_val):
- return x * (max_val - min_val) + min_val
+# Define the complete layer computation function for gradient checkpointing
+def compute_layer_complete(
+ layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
+):
+ models = [paligemma.language_model, gemma_expert.model]
+ query_states = []
+ key_states = []
+ value_states = []
+ gates = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
+ gates.append(gate)
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+ # Concatenate and process attention
+ query_states = torch.cat(query_states, dim=2)
+ key_states = torch.cat(key_states, dim=2)
+ value_states = torch.cat(value_states, dim=2)
+ dummy_tensor = torch.zeros(
+ query_states.shape[0],
+ query_states.shape[2],
+ query_states.shape[-1],
+ device=query_states.device,
+ dtype=query_states.dtype,
+ )
+ cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, unsqueeze_dim=1
+ )
+ batch_size = query_states.shape[0]
+ scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
+ # Attention computation
+ att_output, _ = modeling_gemma.eager_attention_forward(
+ paligemma.language_model.layers[layer_idx].self_attn,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling,
+ )
+ # Get head_dim from the current layer, not from the model
+ head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+ att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
+ # Process layer outputs
+ outputs_embeds = []
+ start_pos = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ end_pos = start_pos + hidden_states.shape[1]
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+ out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
+ # first residual
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
+ after_first_residual = out_emb.clone()
+ out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+ # Convert to bfloat16 if the next layer (mlp) uses bfloat16
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+ out_emb = out_emb.to(dtype=torch.bfloat16)
+ out_emb = layer.mlp(out_emb)
+ # second residual
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
+ outputs_embeds.append(out_emb)
+ start_pos = end_pos
+ return outputs_embeds
-def safe_arcsin(value):
- # This ensures that the input stays within
- # [−1,1] to avoid invalid values for arcsin
- return torch.arcsin(torch.clamp(value, -1.0, 1.0))
+class GemmaConfig: # see openpi `gemma.py: Config`
+ """Configuration for Gemma model variants."""
+
+ def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
+ self.width = width
+ self.depth = depth
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
-def aloha_gripper_to_angular(value):
- # Aloha transforms the gripper positions into a linear space. The following code
- # reverses this transformation to be consistent with pi0 which is pretrained in
- # angular space.
- #
- # These values are coming from the Aloha code:
- # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
- value = unnormalize(value, min_val=0.01844, max_val=0.05800)
-
- # This is the inverse of the angular to linear transformation inside the Interbotix code.
- def linear_to_radian(linear_position, arm_length, horn_radius):
- value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
- return safe_arcsin(value)
-
- # The constants are taken from the Interbotix code.
- value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
-
- # Normalize to [0, 1].
- # The values 0.4 and 1.5 were measured on an actual Trossen robot.
- return normalize(value, min_val=0.4, max_val=1.5)
+def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config`
+ """Returns config for specified gemma variant."""
+ if variant == "gemma_300m":
+ return GemmaConfig(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ elif variant == "gemma_2b":
+ return GemmaConfig(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ else:
+ raise ValueError(f"Unknown variant: {variant}")
-def aloha_gripper_from_angular(value):
- # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
- # Note that the units are still angular but the range is different.
-
- # The values 0.4 and 1.5 were measured on an actual Trossen robot.
- value = unnormalize(value, min_val=0.4, max_val=1.5)
-
- # These values are coming from the Aloha code:
- # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
- return normalize(value, min_val=-0.6213, max_val=1.4910)
-
-
-def aloha_gripper_from_angular_inv(value):
- # Directly inverts the gripper_from_angular function.
- value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
- return normalize(value, min_val=0.4, max_val=1.5)
-
-
-class PI0Policy(PreTrainedPolicy):
- """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
-
- config_class = PI0Config
- name = "pi0"
+class PaliGemmaWithExpertModel(
+ nn.Module
+): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
+ """PaliGemma model with action expert for PI0."""
def __init__(
self,
- config: PI0Config,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
+ vlm_config,
+ action_expert_config,
+ use_adarms=None,
+ precision: Literal["bfloat16", "float32"] = "bfloat16",
):
- """
- Args:
- config: Policy configuration class instance or None, in which case the default instantiation of
- the configuration class is used.
- dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
- that they will be passed with a call to `load_state_dict` before the policy is used.
- """
+ if use_adarms is None:
+ use_adarms = [False, False]
+ super().__init__()
- super().__init__(config)
- config.validate_features()
- self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
+ vlm_config_hf = CONFIG_MAPPING["paligemma"]()
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
+ vlm_config_hf.image_token_index = 257152
+ vlm_config_hf.text_config.hidden_size = vlm_config.width
+ vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
+ vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
+ vlm_config_hf.text_config.head_dim = vlm_config.head_dim
+ vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
+ vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
+ vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
+ vlm_config_hf.text_config.torch_dtype = "float32"
+ vlm_config_hf.text_config.vocab_size = 257152
+ vlm_config_hf.text_config.use_adarms = use_adarms[0]
+ vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
+ vlm_config_hf.vision_config.intermediate_size = 4304
+ vlm_config_hf.vision_config.projection_dim = 2048
+ vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
+ vlm_config_hf.vision_config.torch_dtype = "float32"
+
+ action_expert_config_hf = CONFIG_MAPPING["gemma"](
+ head_dim=action_expert_config.head_dim,
+ hidden_size=action_expert_config.width,
+ intermediate_size=action_expert_config.mlp_dim,
+ num_attention_heads=action_expert_config.num_heads,
+ num_hidden_layers=action_expert_config.depth,
+ num_key_value_heads=action_expert_config.num_kv_heads,
+ vocab_size=257152,
+ hidden_activation="gelu_pytorch_tanh",
+ torch_dtype="float32",
+ use_adarms=use_adarms[1],
+ adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
- self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
- self.model = PI0FlowMatching(config)
+ self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+ self.gemma_expert.model.embed_tokens = None
- self.reset()
+ self.to_bfloat16_for_selected_params(precision)
- def reset(self):
- """This should be called whenever the environment is reset."""
- self._action_queue = deque([], maxlen=self.config.n_action_steps)
+ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
+ if precision == "bfloat16":
+ self.to(dtype=torch.bfloat16)
+ elif precision == "float32":
+ self.to(dtype=torch.float32)
+ return
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
- def get_optim_params(self) -> dict:
- return self.parameters()
+ params_to_keep_float32 = [
+ "vision_tower.vision_model.embeddings.patch_embedding.weight",
+ "vision_tower.vision_model.embeddings.patch_embedding.bias",
+ "vision_tower.vision_model.embeddings.position_embedding.weight",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "model.norm",
+ ]
- @torch.no_grad
- def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
- """Predict a chunk of actions given environment observations."""
- raise NotImplementedError("Currently not implemented for PI0")
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_keep_float32):
+ param.data = param.data.to(dtype=torch.float32)
- @torch.no_grad
- def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
- """Select a single action given environment observations.
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.model.get_image_features(image)
- This method wraps `select_actions` in order to return one action at a time for execution in the
- environment. It works by managing the actions in a queue and only calling `select_actions` when the
- queue is empty.
- """
- self.eval()
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.embed_tokens(tokens)
- if self.config.adapt_to_pi_aloha:
- batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
- batch = self.normalize_inputs(batch)
-
- # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
- # querying the policy.
- if len(self._action_queue) == 0:
- images, img_masks = self.prepare_images(batch)
- state = self.prepare_state(batch)
- lang_tokens, lang_masks = self.prepare_language(batch)
-
- actions = self.model.sample_actions(
- images, img_masks, lang_tokens, lang_masks, state, noise=noise
+ def forward(
+ self,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: list[torch.FloatTensor] | None = None,
+ use_cache: bool | None = None,
+ adarms_cond: list[torch.Tensor] | None = None,
+ ):
+ if adarms_cond is None:
+ adarms_cond = [None, None]
+ if inputs_embeds[1] is None:
+ prefix_output = self.paligemma.language_model.forward(
+ inputs_embeds=inputs_embeds[0],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
-
- # Unpad actions
- original_action_dim = self.config.action_feature.shape[0]
- actions = actions[:, :, :original_action_dim]
-
- actions = self.unnormalize_outputs({"action": actions})["action"]
-
- if self.config.adapt_to_pi_aloha:
- actions = self._pi_aloha_encode_actions(actions)
-
- # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
- # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
- self._action_queue.extend(actions.transpose(0, 1))
- return self._action_queue.popleft()
-
- def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
- """Do a full training forward pass to compute the loss"""
- if self.config.adapt_to_pi_aloha:
- batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
- batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
-
- batch = self.normalize_inputs(batch)
- batch = self.normalize_targets(batch)
-
- images, img_masks = self.prepare_images(batch)
- state = self.prepare_state(batch)
- lang_tokens, lang_masks = self.prepare_language(batch)
- actions = self.prepare_action(batch)
- actions_is_pad = batch.get("action_is_pad")
-
- loss_dict = {}
- losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
- loss_dict["losses_after_forward"] = losses.clone()
-
- if actions_is_pad is not None:
- in_episode_bound = ~actions_is_pad
- losses = losses * in_episode_bound.unsqueeze(-1)
- loss_dict["losses_after_in_ep_bound"] = losses.clone()
-
- # Remove padding
- losses = losses[:, :, : self.config.max_action_dim]
- loss_dict["losses_after_rm_padding"] = losses.clone()
-
- # For backward pass
- loss = losses.mean()
- # For logging
- loss_dict["l2_loss"] = loss.item()
-
- return loss, loss_dict
-
- def prepare_images(self, batch):
- """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
- convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
- """
- images = []
- img_masks = []
-
- present_img_keys = [key for key in self.config.image_features if key in batch]
- missing_img_keys = [key for key in self.config.image_features if key not in batch]
-
- if len(present_img_keys) == 0:
- raise ValueError(
- f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
+ prefix_past_key_values = prefix_output.past_key_values
+ prefix_output = prefix_output.last_hidden_state
+ suffix_output = None
+ elif inputs_embeds[0] is None:
+ suffix_output = self.gemma_expert.model.forward(
+ inputs_embeds=inputs_embeds[1],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
+ suffix_output = suffix_output.last_hidden_state
+ prefix_output = None
+ prefix_past_key_values = None
+ else:
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
- # Preprocess image features present in the batch
- for key in present_img_keys:
- img = batch[key]
+ # Check if gradient checkpointing is enabled for any of the models
+ use_gradient_checkpointing = (
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
+ and self.gemma_expert.model.gradient_checkpointing
+ and self.training
+ ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
- if self.config.resize_imgs_with_padding is not None:
- img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
+ # Process all layers with gradient checkpointing if enabled
+ for layer_idx in range(num_layers):
+ if use_gradient_checkpointing:
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_layer_complete,
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ paligemma=self.paligemma,
+ gemma_expert=self.gemma_expert,
+ )
+ else:
+ inputs_embeds = compute_layer_complete(
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ paligemma=self.paligemma,
+ gemma_expert=self.gemma_expert,
+ )
- # Normalize from range [0,1] to [-1,1] as expected by siglip
- img = img * 2.0 - 1.0
+ # final norm
+ def compute_final_norms(inputs_embeds, adarms_cond):
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
+ outputs_embeds.append(out_emb)
+ return outputs_embeds
- bsize = img.shape[0]
- device = img.device
- mask = torch.ones(bsize, dtype=torch.bool, device=device)
- images.append(img)
- img_masks.append(mask)
+ # Apply gradient checkpointing to final norm if enabled
+ if use_gradient_checkpointing:
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_final_norms,
+ inputs_embeds,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ )
+ else:
+ outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
- # Create image features not present in the batch
- # as fully 0 padded images.
- for num_empty_cameras in range(len(missing_img_keys)):
- if num_empty_cameras >= self.config.empty_cameras:
- break
- img = torch.ones_like(img) * -1
- mask = torch.zeros_like(mask)
- images.append(img)
- img_masks.append(mask)
+ prefix_output = outputs_embeds[0]
+ suffix_output = outputs_embeds[1]
+ prefix_past_key_values = None
- return images, img_masks
-
- def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
- """Tokenize the text input"""
- device = batch[OBS_STATE].device
- tasks = batch["task"]
-
- # PaliGemma prompt has to end with a new line
- tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
-
- tokenized_prompt = self.language_tokenizer.__call__(
- tasks,
- padding="max_length",
- padding_side="right",
- max_length=self.config.tokenizer_max_length,
- return_tensors="pt",
- )
- lang_tokens = tokenized_prompt["input_ids"].to(device=device)
- lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
-
- return lang_tokens, lang_masks
-
- def _pi_aloha_decode_state(self, state):
- # Flip the joints.
- for motor_idx in [1, 2, 8, 9]:
- state[:, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
- return state
-
- def _pi_aloha_encode_actions(self, actions):
- # Flip the joints.
- for motor_idx in [1, 2, 8, 9]:
- actions[:, :, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
- return actions
-
- def _pi_aloha_encode_actions_inv(self, actions):
- # Flip the joints again.
- for motor_idx in [1, 2, 8, 9]:
- actions[:, :, motor_idx] *= -1
- # Reverse the gripper transformation that is being applied by the Aloha runtime.
- for motor_idx in [6, 13]:
- actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
- return actions
-
- def prepare_state(self, batch):
- """Pad state"""
- state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
- return state
-
- def prepare_action(self, batch):
- """Pad action"""
- actions = pad_vector(batch[ACTION], self.config.max_action_dim)
- return actions
+ return [prefix_output, suffix_output], prefix_past_key_values
-class PI0FlowMatching(nn.Module):
- """
- π0: A Vision-Language-Action Flow Model for General Robot Control
+class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
+ """Core PI0 PyTorch model."""
- [Paper](https://www.physicalintelligence.company/download/pi0.pdf)
- [Jax code](https://github.com/Physical-Intelligence/openpi)
-
- Designed by Physical Intelligence. Ported from Jax by Hugging Face.
- ┌──────────────────────────────┐
- │ actions │
- │ ▲ │
- │ ┌┴─────┐ │
- │ kv cache │Gemma │ │
- │ ┌──────────►│Expert│ │
- │ │ │ │ │
- │ ┌┴────────┐ │x 10 │ │
- │ │ │ └▲──▲──┘ │
- │ │PaliGemma│ │ │ │
- │ │ │ │ robot state │
- │ │ │ noise │
- │ └▲──▲─────┘ │
- │ │ │ │
- │ │ image(s) │
- │ language tokens │
- └──────────────────────────────┘
- """
-
- def __init__(self, config):
+ def __init__(self, config: PI0Config):
super().__init__()
self.config = config
- paligemma_with_export_config = PaliGemmaWithExpertConfig(
- freeze_vision_encoder=self.config.freeze_vision_encoder,
- train_expert_only=self.config.train_expert_only,
- attention_implementation=self.config.attention_implementation,
+ paligemma_config = get_gemma_config(config.paligemma_variant)
+ action_expert_config = get_gemma_config(config.action_expert_variant)
+
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
+ paligemma_config,
+ action_expert_config,
+ use_adarms=[False, False],
+ precision=config.dtype,
)
- self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
- # Projections are float32
- self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
- self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
- self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
+ self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
+ self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
- self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
- self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
+ self.state_proj = nn.Linear(config.max_state_dim, action_expert_config.width)
+ self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
+ self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
- self.set_requires_grad()
+ # Initialize gradient checkpointing flag
+ self.gradient_checkpointing_enabled = False
- def set_requires_grad(self):
- for params in self.state_proj.parameters():
- params.requires_grad = self.config.train_state_proj
+ # Compile model if requested
+ if config.compile_model:
+ torch.set_float32_matmul_precision("high")
+ self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
+ # Also compile the main forward pass used during training
+ self.forward = torch.compile(self.forward, mode=config.compile_mode)
+
+ msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
+
+ try:
+ from transformers.models.siglip import check
+
+ if not check.check_whether_transformers_replace_is_installed_correctly():
+ raise ValueError(msg)
+ except ImportError:
+ raise ValueError(msg) from None
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.gradient_checkpointing_enabled = True
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
+ logging.info("Enabled gradient checkpointing for PI0Pytorch model")
+
+ def gradient_checkpointing_disable(self):
+ """Disable gradient checkpointing."""
+ self.gradient_checkpointing_enabled = False
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
+ logging.info("Disabled gradient checkpointing for PI0Pytorch model")
+
+ def _apply_checkpoint(self, func, *args, **kwargs):
+ """Helper method to apply gradient checkpointing if enabled."""
+ if self.gradient_checkpointing_enabled and self.training:
+ return torch.utils.checkpoint.checkpoint(
+ func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
+ )
+ return func(*args, **kwargs)
+
+ def _prepare_attention_masks_4d(self, att_2d_masks):
+ """Helper method to prepare 4D attention masks for transformer."""
+ att_2d_masks_4d = att_2d_masks[:, None, :, :]
+ return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
def sample_noise(self, shape, device):
- noise = torch.normal(
+ return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
- return noise
def sample_time(self, bsize, device):
- time_beta = sample_beta(1.5, 1.0, bsize, device)
- time = time_beta * 0.999 + 0.001
+ time_beta = sample_beta(
+ self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
+ )
+ time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Embed images with SigLIP and language tokens with embedding layer to prepare
- for PaliGemma transformer processing.
- """
- # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
+ """Embed images with SigLIP and language tokens with embedding layer."""
embs = []
pad_masks = []
att_masks = []
- # TODO: remove for loop
- for (
- img,
- img_mask,
- ) in zip(images, img_masks, strict=False):
- img_emb = self.paligemma_with_expert.embed_image(img)
- img_emb = img_emb.to(dtype=torch.bfloat16)
+ # Process images
+ for img, img_mask in zip(images, img_masks, strict=True):
- # Normalize image embeddings
- img_emb_dim = img_emb.shape[-1]
- img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
+ def image_embed_func(img):
+ return self.paligemma_with_expert.embed_image(img)
+ img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
- img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
- pad_masks.append(img_mask)
-
- # Create attention masks so that image tokens attend to each other
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs
- lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
-
- # Normalize language embeddings
- lang_emb_dim = lang_emb.shape[-1]
- lang_emb = lang_emb * math.sqrt(lang_emb_dim)
+ # Process language tokens
+ def lang_embed_func(lang_tokens):
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
+ lang_emb_dim = lang_emb.shape[-1]
+ return lang_emb * math.sqrt(lang_emb_dim)
+ lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
embs.append(lang_emb)
pad_masks.append(lang_masks)
- # full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
+
+ bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
@@ -566,57 +638,67 @@ class PI0FlowMatching(nn.Module):
pad_masks = []
att_masks = []
- # Embed state
- state_emb = self.state_proj(state)
- state_emb = state_emb.to(dtype=torch.bfloat16)
+ if self.state_proj.weight.dtype == torch.float32:
+ state = state.to(torch.float32)
+
+ def state_proj_func(state):
+ return self.state_proj(state)
+
+ state_emb = self._apply_checkpoint(state_proj_func, state)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
- dtype = state_emb.dtype
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
-
- # Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
- # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ # Embed timestep using sine-cosine positional encoding
time_emb = create_sinusoidal_pos_embedding(
- timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
+ timestep,
+ self.action_in_proj.out_features,
+ min_period=self.config.min_period,
+ max_period=self.config.max_period,
+ device=timestep.device,
)
- time_emb = time_emb.type(dtype=dtype)
+ time_emb = time_emb.type(dtype=timestep.dtype)
# Fuse timestep + action information using an MLP
- action_emb = self.action_in_proj(noisy_actions)
+ def action_proj_func(noisy_actions):
+ return self.action_in_proj(noisy_actions)
+
+ action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
- action_time_emb = self.action_time_mlp_in(action_time_emb)
- action_time_emb = F.silu(action_time_emb) # swish == silu
- action_time_emb = self.action_time_mlp_out(action_time_emb)
+ def mlp_func(action_time_emb):
+ x = self.action_time_mlp_in(action_time_emb)
+ x = F.silu(x)
+ return self.action_time_mlp_out(x)
+
+ action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
+ adarms_cond = None
- # Add to input tokens
embs.append(action_time_emb)
-
bsize, action_time_dim = action_time_emb.shape[:2]
- action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
- att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
+ att_masks += [1] + ([0] * (self.config.chunk_size - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
- return embs, pad_masks, att_masks
+ return embs, pad_masks, att_masks, adarms_cond
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
- """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
+ """Do a full training forward pass and compute the loss."""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
@@ -630,7 +712,14 @@ class PI0FlowMatching(nn.Module):
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
- suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
+
+ if (
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
+ == torch.bfloat16
+ ):
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
@@ -638,29 +727,51 @@ class PI0FlowMatching(nn.Module):
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
- (_, suffix_out), _ = self.paligemma_with_expert.forward(
- attention_mask=att_2d_masks,
- position_ids=position_ids,
- past_key_values=None,
- inputs_embeds=[prefix_embs, suffix_embs],
- use_cache=False,
- fill_kv_cache=False,
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+
+ def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
+ attention_mask=att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+ return suffix_out
+
+ suffix_out = self._apply_checkpoint(
+ forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
- suffix_out = suffix_out[:, -self.config.n_action_steps :]
- # Original openpi code, upcast attention output
+
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
- v_t = self.action_out_proj(suffix_out)
- losses = F.mse_loss(u_t, v_t, reduction="none")
- return losses
+ def action_out_proj_func(suffix_out):
+ return self.action_out_proj(suffix_out)
+
+ v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
+
+ return F.mse_loss(u_t, v_t, reduction="none")
+
+ @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
+ def sample_actions(
+ self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
+ ) -> Tensor:
+ """Do a full inference forward and compute the action."""
+ if num_steps is None:
+ num_steps = self.config.num_inference_steps
- def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
- """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
- actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
+ # Sample noise with padded dimension as expected by action_in_proj
+ actions_shape = (
+ bsize,
+ self.config.chunk_size,
+ self.config.max_action_dim,
+ ) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -669,17 +780,18 @@ class PI0FlowMatching(nn.Module):
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
- # Compute image and language key value cache
+ prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
+
_, past_key_values = self.paligemma_with_expert.forward(
- attention_mask=prefix_att_2d_masks,
+ attention_mask=prefix_att_2d_masks_4d,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
- use_cache=self.config.use_cache,
- fill_kv_cache=True,
+ use_cache=True,
)
- dt = -1.0 / self.config.num_steps
+ dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
@@ -693,10 +805,9 @@ class PI0FlowMatching(nn.Module):
x_t,
expanded_time,
)
-
- # Euler step
- x_t += dt * v_t
+ x_t = x_t + dt * v_t
time += dt
+
return x_t
def denoise_step(
@@ -708,30 +819,374 @@ class PI0FlowMatching(nn.Module):
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
- suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
+
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
-
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
-
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+ full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
+
outputs_embeds, _ = self.paligemma_with_expert.forward(
- attention_mask=full_att_2d_masks,
+ attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
- use_cache=self.config.use_cache,
- fill_kv_cache=False,
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
)
+
suffix_out = outputs_embeds[1]
- suffix_out = suffix_out[:, -self.config.n_action_steps :]
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
- v_t = self.action_out_proj(suffix_out)
- return v_t
+ return self.action_out_proj(suffix_out)
+
+
+class PI0Policy(PreTrainedPolicy):
+ """PI0 OpenPI Policy for LeRobot."""
+
+ config_class = PI0Config
+ name = "pi0"
+
+ def __init__(
+ self,
+ config: PI0Config,
+ ):
+ """
+ Args:
+ config: Policy configuration class instance.
+ """
+ super().__init__(config)
+ config.validate_features()
+ self.config = config
+
+ # Initialize the core PI0 model
+ self.model = PI0Pytorch(config)
+
+ # Enable gradient checkpointing if requested
+ if config.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable()
+
+ self.model.to(config.device)
+
+ self.reset()
+
+ @classmethod
+ def from_pretrained(
+ cls: builtins.type[T],
+ pretrained_name_or_path: str | Path,
+ *,
+ config: PreTrainedConfig | None = None,
+ force_download: bool = False,
+ resume_download: bool | None = None,
+ proxies: dict | None = None,
+ token: str | bool | None = None,
+ cache_dir: str | Path | None = None,
+ local_files_only: bool = False,
+ revision: str | None = None,
+ strict: bool = True,
+ **kwargs,
+ ) -> T:
+ """Override the from_pretrained method to handle key remapping and display important disclaimer."""
+ print(
+ "The PI05 model is a direct port of the OpenPI implementation. \n"
+ "This implementation follows the original OpenPI structure for compatibility. \n"
+ "Original implementation: https://github.com/Physical-Intelligence/openpi"
+ )
+ if pretrained_name_or_path is None:
+ raise ValueError("pretrained_name_or_path is required")
+
+ # Use provided config if available, otherwise create default config
+ if config is None:
+ config = PreTrainedConfig.from_pretrained(
+ pretrained_name_or_path=pretrained_name_or_path,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ **kwargs,
+ )
+
+ # Initialize model without loading weights
+ # Check if dataset_stats were provided in kwargs
+ model = cls(config, **kwargs)
+
+ # Now manually load and remap the state dict
+ try:
+ # Try to load the pytorch_model.bin or model.safetensors file
+ print(f"Loading model from: {pretrained_name_or_path}")
+ try:
+ from transformers.utils import cached_file
+
+ # Try safetensors first
+ resolved_file = cached_file(
+ pretrained_name_or_path,
+ "model.safetensors",
+ cache_dir=kwargs.get("cache_dir"),
+ force_download=kwargs.get("force_download", False),
+ resume_download=kwargs.get("resume_download"),
+ proxies=kwargs.get("proxies"),
+ use_auth_token=kwargs.get("use_auth_token"),
+ revision=kwargs.get("revision"),
+ local_files_only=kwargs.get("local_files_only", False),
+ )
+ from safetensors.torch import load_file
+
+ original_state_dict = load_file(resolved_file)
+ print("✓ Loaded state dict from model.safetensors")
+ except Exception as e:
+ print(f"Could not load state dict from remote files: {e}")
+ print("Returning model without loading pretrained weights")
+ return model
+
+ # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
+ fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
+
+ # Then add "model." prefix for all keys that don't already have it
+ remapped_state_dict = {}
+ remap_count = 0
+
+ for key, value in fixed_state_dict.items():
+ if not key.startswith("model."):
+ new_key = f"model.{key}"
+ remapped_state_dict[new_key] = value
+ remap_count += 1
+ else:
+ remapped_state_dict[key] = value
+
+ if remap_count > 0:
+ print(f"Remapped {remap_count} state dict keys")
+
+ # Load the remapped state dict into the model
+ missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+ if missing_keys:
+ print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
+ if len(missing_keys) <= 5:
+ for key in missing_keys:
+ print(f" - {key}")
+ else:
+ for key in missing_keys[:5]:
+ print(f" - {key}")
+ print(f" ... and {len(missing_keys) - 5} more")
+
+ if unexpected_keys:
+ print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
+ if len(unexpected_keys) <= 5:
+ for key in unexpected_keys:
+ print(f" - {key}")
+ else:
+ for key in unexpected_keys[:5]:
+ print(f" - {key}")
+ print(f" ... and {len(unexpected_keys) - 5} more")
+
+ if not missing_keys and not unexpected_keys:
+ print("All keys loaded successfully!")
+
+ except Exception as e:
+ print(f"Warning: Could not remap state dict keys: {e}")
+
+ return model
+
+ def _fix_pytorch_state_dict_keys(
+ self, state_dict, model_config
+ ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
+ """Fix state dict keys to match current model architecture."""
+ import re
+
+ fixed_state_dict = {}
+
+ for key, value in state_dict.items():
+ new_key = key
+
+ # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
+ # For gemma expert layers
+ if re.match(
+ r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
+ key,
+ ):
+ # Check if the model actually has adaRMS enabled for the expert
+ expert_uses_adarms = getattr(
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
+ )
+ if expert_uses_adarms:
+ logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
+ continue
+
+ if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
+ # Check if the model actually has adaRMS enabled for the expert
+ expert_uses_adarms = getattr(
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
+ )
+ if expert_uses_adarms:
+ logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
+ continue
+
+ # Handle MLP naming changes for pi0
+ # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_*
+ if key.startswith("time_mlp_in."):
+ new_key = key.replace("time_mlp_in.", "action_time_mlp_in.")
+ elif key.startswith("time_mlp_out."):
+ new_key = key.replace("time_mlp_out.", "action_time_mlp_out.")
+
+ # Handle vision tower embedding layer potential differences
+ if "patch_embedding" in key:
+ # Some checkpoints might have this, but current model expects different structure
+ logging.warning(f"Vision embedding key might need handling: {key}")
+
+ fixed_state_dict[new_key] = value
+
+ return fixed_state_dict
+
+ def get_optim_params(self) -> dict:
+ return self.parameters()
+
+ def reset(self):
+ """Reset internal state - called when environment resets."""
+ self._action_queue = deque(maxlen=self.config.n_action_steps)
+ self._queues = {
+ ACTION: deque(maxlen=self.config.n_action_steps),
+ }
+
+ def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
+ """Preprocess images for the model.
+
+ Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
+ PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
+ """
+ images = []
+ img_masks = []
+
+ # Get device from model parameters
+ device = next(self.parameters()).device
+
+ present_img_keys = [key for key in self.config.image_features if key in batch]
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
+
+ if len(present_img_keys) == 0:
+ raise ValueError(
+ f"All image features are missing from the batch. At least one expected. "
+ f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
+ )
+
+ for key in present_img_keys:
+ img = batch[key]
+
+ # Ensure tensor is on the same device as the model
+ if img.device != device:
+ img = img.to(device)
+
+ # Ensure float32 dtype for consistency
+ if img.dtype != torch.float32:
+ img = img.to(torch.float32)
+
+ # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
+ is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
+
+ if is_channels_first:
+ # Convert [B, C, H, W] to [B, H, W, C] for processing
+ img = img.permute(0, 2, 3, 1)
+
+ # from openpi preprocess_observation_pytorch: Resize with padding if needed
+ if img.shape[1:3] != self.config.image_resolution:
+ img = resize_with_pad_torch(img, *self.config.image_resolution)
+
+ # Normalize from [0,1] to [-1,1] as expected by siglip
+ img = img * 2.0 - 1.0
+
+ # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
+ if is_channels_first:
+ img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
+
+ images.append(img)
+ # Create mask (all ones for real images)
+ bsize = img.shape[0]
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
+ img_masks.append(mask)
+
+ # Create image features not present in the batch as fully 0 padded images
+ for _num_empty_cameras in range(len(missing_img_keys)):
+ img = torch.ones_like(img) * -1 # padded with -1 for SigLIP
+ mask = torch.zeros_like(mask) # mask is zero for empty cameras
+ images.append(img)
+ img_masks.append(mask)
+
+ return images, img_masks
+
+ def prepare_state(self, batch):
+ """Pad state"""
+ state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
+ return state
+
+ def prepare_action(self, batch):
+ """Pad action"""
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
+ return actions
+
+ @torch.no_grad()
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ """Select a single action given environment observations."""
+ self.eval()
+
+ # Action queue logic for n_action_steps > 1
+ if len(self._action_queue) == 0:
+ actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
+ # Transpose to get shape (n_action_steps, batch_size, action_dim)
+ self._action_queue.extend(actions.transpose(0, 1))
+
+ return self._action_queue.popleft()
+
+ @torch.no_grad()
+ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ """Predict a chunk of actions given environment observations."""
+ self.eval()
+
+ # Prepare inputs
+ images, img_masks = self._preprocess_images(batch)
+ lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+ state = self.prepare_state(batch)
+
+ # Sample actions using the model
+ actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
+
+ # Unpad actions to actual action dimension
+ original_action_dim = self.config.output_features[ACTION].shape[0]
+ actions = actions[:, :, :original_action_dim]
+
+ return actions
+
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
+ """Run the batch through the model and compute the loss for training."""
+
+ # Prepare inputs
+ images, img_masks = self._preprocess_images(batch)
+ lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+ state = self.prepare_state(batch)
+ actions = self.prepare_action(batch)
+
+ # Compute loss
+ losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
+
+ # Truncate losses to actual action dimensions
+ original_action_dim = self.config.output_features[ACTION].shape[0]
+ losses = losses[:, :, :original_action_dim]
+
+ loss = losses.mean()
+
+ loss_dict = {
+ "loss": loss.item(),
+ "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
+ }
+
+ return loss, loss_dict
diff --git a/src/lerobot/policies/pi0/paligemma_with_expert.py b/src/lerobot/policies/pi0/paligemma_with_expert.py
deleted file mode 100644
index f0f5713e5..000000000
--- a/src/lerobot/policies/pi0/paligemma_with_expert.py
+++ /dev/null
@@ -1,421 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional, Union
-
-import torch
-import torch.version
-from pytest import Cache
-from torch import nn
-from transformers import (
- AutoConfig,
- GemmaForCausalLM,
- PaliGemmaForConditionalGeneration,
- PretrainedConfig,
- PreTrainedModel,
-)
-from transformers.models.auto import CONFIG_MAPPING
-
-from lerobot.policies.pi0.flex_attention import flex_attention_forward
-
-
-def apply_rope(x, positions, max_wavelength=10_000):
- """
- Applies RoPE positions [B, L] to x [B, L, H, D].
- """
- d_half = x.shape[-1] // 2
- device = x.device
- dtype = x.dtype
- x = x.to(torch.float32)
-
- freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
- timescale = max_wavelength**freq_exponents
- radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
-
- radians = radians[..., None, :]
-
- sin = torch.sin(radians) # .to(dtype=dtype)
- cos = torch.cos(radians) # .to(dtype=dtype)
-
- x1, x2 = x.split(d_half, dim=-1)
- res = torch.empty_like(x)
- res[..., :d_half] = x1 * cos - x2 * sin
- res[..., d_half:] = x2 * cos + x1 * sin
-
- return res.to(dtype)
-
-
-class PaliGemmaWithExpertConfig(PretrainedConfig):
- model_type = "PaliGemmaWithExpertModel"
- sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
-
- def __init__(
- self,
- paligemma_config: dict | None = None,
- gemma_expert_config: dict | None = None,
- freeze_vision_encoder: bool = True,
- train_expert_only: bool = True,
- attention_implementation: str = "eager",
- **kwargs,
- ):
- self.freeze_vision_encoder = freeze_vision_encoder
- self.train_expert_only = train_expert_only
- self.attention_implementation = attention_implementation
-
- if paligemma_config is None:
- # Default config from Pi0
- self.paligemma_config = CONFIG_MAPPING["paligemma"](
- transformers_version="4.48.1",
- _vocab_size=257152,
- bos_token_id=2,
- eos_token_id=1,
- hidden_size=2048,
- image_token_index=257152,
- model_type="paligemma",
- pad_token_id=0,
- projection_dim=2048,
- text_config={
- "hidden_activation": "gelu_pytorch_tanh",
- "hidden_size": 2048,
- "intermediate_size": 16384,
- "model_type": "gemma",
- "num_attention_heads": 8,
- "num_hidden_layers": 18,
- "num_image_tokens": 256,
- "num_key_value_heads": 1,
- "torch_dtype": "float32",
- "vocab_size": 257152,
- },
- vision_config={
- "hidden_size": 1152,
- "intermediate_size": 4304,
- "model_type": "siglip_vision_model",
- "num_attention_heads": 16,
- "num_hidden_layers": 27,
- "num_image_tokens": 256,
- "patch_size": 14,
- "projection_dim": 2048,
- "projector_hidden_act": "gelu_fast",
- "torch_dtype": "float32",
- "vision_use_head": False,
- },
- )
- elif isinstance(self.paligemma_config, dict):
- # Override Pi0 default config for PaliGemma
- if "model_type" not in gemma_expert_config:
- paligemma_config["model_type"] = "paligemma"
-
- cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
- self.paligemma_config = cfg_cls(**paligemma_config)
-
- if gemma_expert_config is None:
- # Default config from Pi0
- self.gemma_expert_config = CONFIG_MAPPING["gemma"](
- attention_bias=False,
- attention_dropout=0.0,
- bos_token_id=2,
- eos_token_id=1,
- head_dim=256,
- hidden_act="gelu_pytorch_tanh",
- hidden_activation="gelu_pytorch_tanh",
- hidden_size=1024,
- initializer_range=0.02,
- intermediate_size=4096,
- max_position_embeddings=8192,
- model_type="gemma",
- num_attention_heads=8,
- num_hidden_layers=18,
- num_key_value_heads=1,
- pad_token_id=0,
- rms_norm_eps=1e-06,
- rope_theta=10000.0,
- torch_dtype="float32",
- transformers_version="4.48.1",
- use_cache=True,
- vocab_size=257152,
- )
- elif isinstance(self.gemma_expert_config, dict):
- # Override Pi0 default config for Gemma Expert
- if "model_type" not in gemma_expert_config:
- gemma_expert_config["model_type"] = "gemma"
-
- cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
- self.gemma_expert_config = cfg_cls(**gemma_expert_config)
-
- super().__init__(**kwargs)
-
- def __post_init__(self):
- super().__post_init__()
- if self.train_expert_only and not self.freeze_vision_encoder:
- raise ValueError(
- "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
- )
-
- if self.attention_implementation not in ["eager", "fa2", "flex"]:
- raise ValueError(
- f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
- )
-
-
-class PaliGemmaWithExpertModel(PreTrainedModel):
- config_class = PaliGemmaWithExpertConfig
-
- def __init__(self, config: PaliGemmaWithExpertConfig):
- super().__init__(config=config)
- self.config = config
- self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
- self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
- # Remove unused embed_tokens
- self.gemma_expert.model.embed_tokens = None
-
- self.to_bfloat16_like_physical_intelligence()
- self.set_requires_grad()
-
- def set_requires_grad(self):
- if self.config.freeze_vision_encoder:
- self.paligemma.vision_tower.eval()
- for params in self.paligemma.vision_tower.parameters():
- params.requires_grad = False
-
- if self.config.train_expert_only:
- self.paligemma.eval()
- for params in self.paligemma.parameters():
- params.requires_grad = False
-
- def train(self, mode: bool = True):
- super().train(mode)
-
- if self.config.freeze_vision_encoder:
- self.paligemma.vision_tower.eval()
-
- if self.config.train_expert_only:
- self.paligemma.eval()
-
- def to_bfloat16_like_physical_intelligence(self):
- self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
-
- params_to_change_dtype = [
- "language_model.model.layers",
- "gemma_expert.model.layers",
- "vision_tower",
- "multi_modal",
- ]
- for name, param in self.named_parameters():
- if any(selector in name for selector in params_to_change_dtype):
- param.data = param.data.to(dtype=torch.bfloat16)
-
- def embed_image(self, image: torch.Tensor):
- # Handle different transformers versions
- if hasattr(self.paligemma, "get_image_features"):
- return self.paligemma.get_image_features(image)
- else:
- return self.paligemma.model.get_image_features(image)
-
- def embed_language_tokens(self, tokens: torch.Tensor):
- return self.paligemma.language_model.embed_tokens(tokens)
-
- # TODO: break down this huge forward into modules or functions
- def forward(
- self,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
- inputs_embeds: List[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- fill_kv_cache: Optional[bool] = None,
- ):
- models = [self.paligemma.language_model, self.gemma_expert.model]
-
- for hidden_states in inputs_embeds:
- # TODO this is very inefficient
- # dtype is always the same, batch size too (if > 1 len)
- # device could be trickier in multi gpu edge cases but that's it
- if hidden_states is None:
- continue
- batch_size = hidden_states.shape[0]
-
- # RMSNorm
- num_layers = self.paligemma.config.text_config.num_hidden_layers
- head_dim = self.paligemma.config.text_config.head_dim
- for layer_idx in range(num_layers):
- query_states = []
- key_states = []
- value_states = []
- for i, hidden_states in enumerate(inputs_embeds):
- if hidden_states is None:
- continue
- layer = models[i].layers[layer_idx]
- # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
- # hidden_states = hidden_states * normalizer
- hidden_states = layer.input_layernorm(hidden_states)
-
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
-
- hidden_states = hidden_states.to(dtype=torch.bfloat16)
- query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
- key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
- value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
-
- query_states.append(query_state)
- key_states.append(key_state)
- value_states.append(value_state)
-
- # B,L,H,D with L sequence length, H number of heads, D head dim
- # concatenate on the number of embeddings/tokens
- query_states = torch.cat(query_states, dim=1)
- key_states = torch.cat(key_states, dim=1)
- value_states = torch.cat(value_states, dim=1)
-
- query_states = apply_rope(query_states, position_ids)
- key_states = apply_rope(key_states, position_ids)
-
- if use_cache and past_key_values is None:
- past_key_values = {}
-
- if use_cache:
- if fill_kv_cache:
- past_key_values[layer_idx] = {
- "key_states": key_states,
- "value_states": value_states,
- }
- else:
- # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
- # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
- # the max len, then we (for instance) double the cache size. This implementation already exists
- # in `transformers`. (molbap)
- key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
- value_states = torch.cat(
- [past_key_values[layer_idx]["value_states"], value_states], dim=1
- )
-
- attention_interface = self.get_attention_interface()
- att_output = attention_interface(
- attention_mask, batch_size, head_dim, query_states, key_states, value_states
- )
- att_output = att_output.to(dtype=torch.bfloat16)
-
- # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
- outputs_embeds = []
- start = 0
- for i, hidden_states in enumerate(inputs_embeds):
- layer = models[i].layers[layer_idx]
-
- if hidden_states is not None:
- end = start + hidden_states.shape[1]
-
- if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
- att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
- out_emb = layer.self_attn.o_proj(att_output[:, start:end])
-
- # TODO: first dropout (by default 0.0)
-
- # first residual
- out_emb += hidden_states
- after_first_residual = out_emb.clone()
-
- out_emb = layer.post_attention_layernorm(out_emb)
- out_emb = layer.mlp(out_emb)
-
- # TODO: second dropout (by default 0.0)
-
- # second residual
- out_emb += after_first_residual
-
- outputs_embeds.append(out_emb)
-
- start = end
- else:
- outputs_embeds.append(None)
-
- inputs_embeds = outputs_embeds
-
- # final norm
- outputs_embeds = []
- for i, hidden_states in enumerate(inputs_embeds):
- if hidden_states is not None:
- out_emb = models[i].norm(hidden_states)
- outputs_embeds.append(out_emb)
- else:
- outputs_embeds.append(None)
-
- return outputs_embeds, past_key_values
-
- def get_attention_interface(self):
- if self.config.attention_implementation == "fa2":
- attention_interface = self.flash_attention_forward
- elif self.config.attention_implementation == "flex":
- attention_interface = flex_attention_forward
- else:
- attention_interface = self.eager_attention_forward
- return attention_interface
-
- def flash_attention_forward(
- self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
- ):
- raise NotImplementedError("FA2 is not implemented (yet)")
-
- def eager_attention_forward(
- self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
- ):
- num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
- num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
- num_key_value_groups = num_att_heads // num_key_value_heads
-
- # query_states: batch_size, sequence_length, num_att_head, head_dim
- # key_states: batch_size, sequence_length, num_key_value_head, head_dim
- # value_states: batch_size, sequence_length, num_key_value_head, head_dim
- sequence_length = key_states.shape[1]
-
- key_states = key_states[:, :, :, None, :].expand(
- batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
- )
- key_states = key_states.reshape(
- batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
- )
-
- value_states = value_states[:, :, :, None, :].expand(
- batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
- )
- value_states = value_states.reshape(
- batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
- )
-
- # Attention here is upcasted to float32 to match the original eager implementation.
-
- query_states = query_states.to(dtype=torch.float32)
- key_states = key_states.to(dtype=torch.float32)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
-
- att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
- att_weights *= head_dim**-0.5
- big_neg = -2.3819763e38 # See gemma/modules.py
-
- masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
-
- probs = nn.functional.softmax(masked_att_weights, dim=-1)
- probs = probs.to(dtype=value_states.dtype)
-
- # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
- # value_states: batch_size, sequence_length, num_att_heads, head_dim
-
- att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
-
- att_output = att_output.permute(0, 2, 1, 3)
- # we use -1 because sequence length can change
- att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
-
- return att_output
diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py
new file mode 100644
index 000000000..50f5dec83
--- /dev/null
+++ b/src/lerobot/policies/pi0/processor_pi0.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ ComplementaryDataProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ RenameObservationsProcessorStep,
+ TokenizerProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+@ProcessorStepRegistry.register(name="pi0_new_line_processor")
+class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
+ """
+ Ensures that the task description string ends with a newline character.
+
+ This processing step is required for compatibility with the PaliGemma tokenizer,
+ which expects a newline at the end of the text prompt. It handles both single
+ strings and lists of strings for the 'task' key in complementary data.
+ """
+
+ def complementary_data(self, complementary_data):
+ """
+ Adds a newline to the 'task' field if it doesn't already have one.
+
+ Args:
+ complementary_data: A dictionary that may contain a 'task' key with a
+ string or list of strings.
+
+ Returns:
+ A new dictionary with the modified 'task' field.
+ """
+ if "task" not in complementary_data:
+ return complementary_data
+
+ task = complementary_data["task"]
+ if task is None:
+ return complementary_data
+
+ new_complementary_data = dict(complementary_data)
+
+ # Handle both string and list of strings
+ if isinstance(task, str):
+ # Single string: add newline if not present
+ if not task.endswith("\n"):
+ new_complementary_data["task"] = f"{task}\n"
+ elif isinstance(task, list) and all(isinstance(t, str) for t in task):
+ # List of strings: add newline to each if not present
+ new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
+ # If task is neither string nor list of strings, leave unchanged
+
+ return new_complementary_data
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ This step does not alter the feature definitions.
+
+ Args:
+ features: The input feature dictionary.
+
+ Returns:
+ The unchanged feature dictionary.
+ """
+ return features
+
+
+def make_pi0_pre_post_processors(
+ config: PI0Config,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the PI0 policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Appending a newline character to the task description for tokenizer compatibility.
+ 5. Tokenizing the text prompt using the PaliGemma tokenizer.
+ 6. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the PI0 policy.
+ dataset_stats: A dictionary of statistics for normalization.
+ preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
+ postprocessor_kwargs: Additional arguments for the post-processor pipeline.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ # Add remaining processors
+ input_steps: list[ProcessorStep] = [
+ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
+ AddBatchDimensionProcessorStep(),
+ Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
+ TokenizerProcessorStep(
+ tokenizer_name="google/paligemma-3b-pt-224",
+ max_length=config.tokenizer_max_length,
+ padding_side="right",
+ padding="max_length",
+ ),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+
+ output_steps: list[ProcessorStep] = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/pi05/README.md b/src/lerobot/policies/pi05/README.md
new file mode 100644
index 000000000..2ae69d978
--- /dev/null
+++ b/src/lerobot/policies/pi05/README.md
@@ -0,0 +1,49 @@
+# π₀.₅ (pi05)
+
+This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
+It is designed as a **Vision-Language-Action model with open-world generalization**.
+
+---
+
+## Model Overview
+
+| Feature | π₀ | π₀.₅ |
+| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
+| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
+| AdaRMS | Not used | Used in action expert |
+| Tokenizer Length | 48 tokens | 200 tokens |
+| Discrete State Input | False (Uses `state_proj` layer) | True |
+| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
+
+---
+
+## Citation
+
+If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
+
+```bibtex
+@misc{openpi2024,
+ author = {Physical Intelligence Lab},
+ title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
+ year = {2024},
+ publisher = {GitHub},
+ howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
+ license = {Apache-2.0}
+}
+
+@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
+ title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
+ author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
+ year = {2025},
+ eprint = {2504.16054},
+ archivePrefix= {arXiv},
+ primaryClass = {cs.LG},
+ url = {https://arxiv.org/abs/2504.16054},
+}
+```
+
+---
+
+## License
+
+This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
diff --git a/src/lerobot/policies/pi05/__init__.py b/src/lerobot/policies/pi05/__init__.py
new file mode 100644
index 000000000..4f9a9de4a
--- /dev/null
+++ b/src/lerobot/policies/pi05/__init__.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_pi05 import PI05Config
+from .modeling_pi05 import PI05Policy
+from .processor_pi05 import make_pi05_pre_post_processors
+
+__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py
new file mode 100644
index 000000000..7c1e950b0
--- /dev/null
+++ b/src/lerobot/policies/pi05/configuration_pi05.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.optim.optimizers import AdamWConfig
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+
+
+@PreTrainedConfig.register_subclass("pi05")
+@dataclass
+class PI05Config(PreTrainedConfig):
+ paligemma_variant: str = "gemma_2b"
+ action_expert_variant: str = "gemma_300m"
+ dtype: str = "float32" # Options: "bfloat16", "float32"
+
+ n_obs_steps: int = 1
+ chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
+ n_action_steps: int = 50 # Number of action steps to execute
+
+ # Shorter state and action vectors will be padded to these dimensions
+ max_state_dim: int = 32
+ max_action_dim: int = 32
+
+ # Flow matching parameters: see openpi `PI0Pytorch`
+ num_inference_steps: int = 10
+ time_sampling_beta_alpha: float = 1.5
+ time_sampling_beta_beta: float = 1.0
+ time_sampling_scale: float = 0.999
+ time_sampling_offset: float = 0.001
+ min_period: float = 4e-3
+ max_period: float = 4.0
+
+ image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
+
+ # Add empty images. Used to add empty cameras when no image features are present.
+ empty_cameras: int = 0
+
+ tokenizer_max_length: int = 200 # see openpi `__post_init__`
+
+ normalization_mapping: dict[str, NormalizationMode] = field(
+ default_factory=lambda: {
+ "VISUAL": NormalizationMode.IDENTITY,
+ "STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
+ "ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
+ }
+ )
+
+ # Training settings
+ gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
+ compile_model: bool = False # Whether to use torch.compile for model optimization
+ compile_mode: str = "max-autotune" # Torch compile mode
+ device: str | None = None # Device to use for the model (None = auto-detect)
+
+ # Optimizer settings: see openpi `AdamW`
+ optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
+ optimizer_betas: tuple[float, float] = (0.9, 0.95)
+ optimizer_eps: float = 1e-8
+ optimizer_weight_decay: float = 0.01
+ optimizer_grad_clip_norm: float = 1.0
+
+ # Scheduler settings: see openpi `CosineDecaySchedule`
+ scheduler_warmup_steps: int = 1_000
+ scheduler_decay_steps: int = 30_000
+ scheduler_decay_lr: float = 2.5e-6
+
+ tokenizer_max_length: int = 200 # see openpi `__post_init__`
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ # Validate configuration
+ if self.n_action_steps > self.chunk_size:
+ raise ValueError(
+ f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
+ )
+
+ if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
+ raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
+
+ if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
+ raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
+
+ if self.dtype not in ["bfloat16", "float32"]:
+ raise ValueError(f"Invalid dtype: {self.dtype}")
+
+ def validate_features(self) -> None:
+ """Validate and set up input/output features."""
+ for i in range(self.empty_cameras):
+ key = f"observation.images.empty_camera_{i}"
+ empty_camera = PolicyFeature(
+ type=FeatureType.VISUAL,
+ shape=(3, *self.image_resolution), # Use configured image resolution
+ )
+ self.input_features[key] = empty_camera
+
+ if "observation.state" not in self.input_features:
+ state_feature = PolicyFeature(
+ type=FeatureType.STATE,
+ shape=(self.max_state_dim,), # Padded to max_state_dim
+ )
+ self.input_features["observation.state"] = state_feature
+
+ if "action" not in self.output_features:
+ action_feature = PolicyFeature(
+ type=FeatureType.ACTION,
+ shape=(self.max_action_dim,), # Padded to max_action_dim
+ )
+ self.output_features["action"] = action_feature
+
+ def get_optimizer_preset(self) -> AdamWConfig:
+ return AdamWConfig(
+ lr=self.optimizer_lr,
+ betas=self.optimizer_betas,
+ eps=self.optimizer_eps,
+ weight_decay=self.optimizer_weight_decay,
+ grad_clip_norm=self.optimizer_grad_clip_norm,
+ )
+
+ def get_scheduler_preset(self):
+ return CosineDecayWithWarmupSchedulerConfig(
+ peak_lr=self.optimizer_lr,
+ decay_lr=self.scheduler_decay_lr,
+ num_warmup_steps=self.scheduler_warmup_steps,
+ num_decay_steps=self.scheduler_decay_steps,
+ )
+
+ @property
+ def observation_delta_indices(self) -> None:
+ return None
+
+ @property
+ def action_delta_indices(self) -> list:
+ return list(range(self.chunk_size))
+
+ @property
+ def reward_delta_indices(self) -> None:
+ return None
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
new file mode 100644
index 000000000..93ca5fa82
--- /dev/null
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -0,0 +1,1163 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import builtins
+import logging
+import math
+from collections import deque
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from torch import Tensor, nn
+
+from lerobot.utils.import_utils import _transformers_available
+
+# Conditional import for type checking and lazy loading
+if TYPE_CHECKING or _transformers_available:
+ from transformers.models.auto import CONFIG_MAPPING
+ from transformers.models.gemma import modeling_gemma
+ from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+else:
+ CONFIG_MAPPING = None
+ modeling_gemma = None
+ GemmaForCausalLM = None
+ PaliGemmaForConditionalGeneration = None
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.policies.pi05.configuration_pi05 import PI05Config
+from lerobot.policies.pretrained import PreTrainedPolicy, T
+from lerobot.utils.constants import (
+ ACTION,
+ OBS_LANGUAGE_ATTENTION_MASK,
+ OBS_LANGUAGE_TOKENS,
+ OPENPI_ATTENTION_MASK_VALUE,
+)
+
+
+def get_safe_dtype(target_dtype, device_type):
+ """Get a safe dtype for the given device type."""
+ if device_type == "mps" and target_dtype == torch.float64:
+ return torch.float32
+ if device_type == "cpu":
+ # CPU doesn't support bfloat16, use float32 instead
+ if target_dtype == torch.bfloat16:
+ return torch.float32
+ if target_dtype == torch.float64:
+ return torch.float64
+ return target_dtype
+
+
+def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
+ time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+) -> Tensor:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
+ if dimension % 2 != 0:
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+
+ if time.ndim != 1:
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
+
+ dtype = get_safe_dtype(torch.float64, device.type)
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
+ period = min_period * (max_period / min_period) ** fraction
+
+ # Compute the outer product
+ scaling_factor = 1.0 / period * 2 * math.pi
+ sin_input = scaling_factor[None, :] * time[:, None]
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
+
+def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
+ alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+ beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+ dist = torch.distributions.Beta(alpha_t, beta_t)
+ return dist.sample((bsize,))
+
+
+def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
+ """Copied from big_vision.
+
+ Tokens can attend to valid inputs tokens which have a cumulative mask_ar
+ smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
+ setup several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+ input_mask: bool[B, N] true if its part of the input, false if padding.
+ mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
+ it and 0 where it shares the same attention mask as the previous token.
+ """
+ if att_masks.ndim != 2:
+ raise ValueError(att_masks.ndim)
+ if pad_masks.ndim != 2:
+ raise ValueError(pad_masks.ndim)
+
+ cumsum = torch.cumsum(att_masks, dim=1)
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
+ return att_2d_masks & pad_2d_masks
+
+
+def pad_vector(vector, new_dim):
+ """Pad the last dimension of a vector to new_dim with zeros.
+
+ Can be (batch_size x sequence_length x features_dimension)
+ or (batch_size x features_dimension)
+ """
+ if vector.shape[-1] >= new_dim:
+ return vector
+ return F.pad(vector, (0, new_dim - vector.shape[-1]))
+
+
+def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
+ images: torch.Tensor,
+ height: int,
+ width: int,
+ mode: str = "bilinear",
+) -> torch.Tensor:
+ """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
+ by padding with black. If the image is float32, it must be in the range [-1, 1].
+
+ Args:
+ images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
+ height: Target height
+ width: Target width
+ mode: Interpolation mode ('bilinear', 'nearest', etc.)
+
+ Returns:
+ Resized and padded tensor with same shape format as input
+ """
+ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
+ if images.shape[-1] <= 4: # Assume channels-last format
+ channels_last = True
+ if images.dim() == 3:
+ images = images.unsqueeze(0) # Add batch dimension
+ images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
+ else:
+ channels_last = False
+ if images.dim() == 3:
+ images = images.unsqueeze(0) # Add batch dimension
+
+ batch_size, channels, cur_height, cur_width = images.shape
+
+ # Calculate resize ratio
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+
+ # Resize
+ resized_images = F.interpolate(
+ images,
+ size=(resized_height, resized_width),
+ mode=mode,
+ align_corners=False if mode == "bilinear" else None,
+ )
+
+ # Handle dtype-specific clipping
+ if images.dtype == torch.uint8:
+ resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
+ elif images.dtype == torch.float32:
+ resized_images = resized_images.clamp(-1.0, 1.0)
+ else:
+ raise ValueError(f"Unsupported image dtype: {images.dtype}")
+
+ # Calculate padding
+ pad_h0, remainder_h = divmod(height - resized_height, 2)
+ pad_h1 = pad_h0 + remainder_h
+ pad_w0, remainder_w = divmod(width - resized_width, 2)
+ pad_w1 = pad_w0 + remainder_w
+
+ # Pad
+ constant_value = 0 if images.dtype == torch.uint8 else -1.0
+ padded_images = F.pad(
+ resized_images,
+ (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
+ mode="constant",
+ value=constant_value,
+ )
+
+ # Convert back to original format if needed
+ if channels_last:
+ padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
+
+ return padded_images
+
+
+# Define the complete layer computation function for gradient checkpointing
+def compute_layer_complete(
+ layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
+):
+ models = [paligemma.language_model, gemma_expert.model]
+ query_states = []
+ key_states = []
+ value_states = []
+ gates = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
+ gates.append(gate)
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+ # Concatenate and process attention
+ query_states = torch.cat(query_states, dim=2)
+ key_states = torch.cat(key_states, dim=2)
+ value_states = torch.cat(value_states, dim=2)
+ dummy_tensor = torch.zeros(
+ query_states.shape[0],
+ query_states.shape[2],
+ query_states.shape[-1],
+ device=query_states.device,
+ dtype=query_states.dtype,
+ )
+ cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, unsqueeze_dim=1
+ )
+ batch_size = query_states.shape[0]
+ scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
+ # Attention computation
+ att_output, _ = modeling_gemma.eager_attention_forward(
+ paligemma.language_model.layers[layer_idx].self_attn,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling,
+ )
+ # Get head_dim from the current layer, not from the model
+ head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+ att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
+ # Process layer outputs
+ outputs_embeds = []
+ start_pos = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ end_pos = start_pos + hidden_states.shape[1]
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+ out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
+ # first residual
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
+ after_first_residual = out_emb.clone()
+ out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+ # Convert to bfloat16 if the next layer (mlp) uses bfloat16
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+ out_emb = out_emb.to(dtype=torch.bfloat16)
+ out_emb = layer.mlp(out_emb)
+ # second residual
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
+ outputs_embeds.append(out_emb)
+ start_pos = end_pos
+ return outputs_embeds
+
+
+class GemmaConfig: # see openpi `gemma.py: Config`
+ """Configuration for Gemma model variants."""
+
+ def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
+ self.width = width
+ self.depth = depth
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+
+
+def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config`
+ """Returns config for specified gemma variant."""
+ if variant == "gemma_300m":
+ return GemmaConfig(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ elif variant == "gemma_2b":
+ return GemmaConfig(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ else:
+ raise ValueError(f"Unknown variant: {variant}")
+
+
+class PaliGemmaWithExpertModel(
+ nn.Module
+): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
+ """PaliGemma model with action expert for PI05."""
+
+ def __init__(
+ self,
+ vlm_config,
+ action_expert_config,
+ use_adarms=None,
+ precision: Literal["bfloat16", "float32"] = "bfloat16",
+ ):
+ if use_adarms is None:
+ use_adarms = [False, False]
+ super().__init__()
+
+ vlm_config_hf = CONFIG_MAPPING["paligemma"]()
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
+ vlm_config_hf.image_token_index = 257152
+ vlm_config_hf.text_config.hidden_size = vlm_config.width
+ vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
+ vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
+ vlm_config_hf.text_config.head_dim = vlm_config.head_dim
+ vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
+ vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
+ vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
+ vlm_config_hf.text_config.torch_dtype = "float32"
+ vlm_config_hf.text_config.vocab_size = 257152
+ vlm_config_hf.text_config.use_adarms = use_adarms[0]
+ vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
+ vlm_config_hf.vision_config.intermediate_size = 4304
+ vlm_config_hf.vision_config.projection_dim = 2048
+ vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
+ vlm_config_hf.vision_config.torch_dtype = "float32"
+
+ action_expert_config_hf = CONFIG_MAPPING["gemma"](
+ head_dim=action_expert_config.head_dim,
+ hidden_size=action_expert_config.width,
+ intermediate_size=action_expert_config.mlp_dim,
+ num_attention_heads=action_expert_config.num_heads,
+ num_hidden_layers=action_expert_config.depth,
+ num_key_value_heads=action_expert_config.num_kv_heads,
+ vocab_size=257152,
+ hidden_activation="gelu_pytorch_tanh",
+ torch_dtype="float32",
+ use_adarms=use_adarms[1],
+ adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
+ )
+
+ self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+ self.gemma_expert.model.embed_tokens = None
+
+ self.to_bfloat16_for_selected_params(precision)
+
+ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
+ if precision == "bfloat16":
+ self.to(dtype=torch.bfloat16)
+ elif precision == "float32":
+ self.to(dtype=torch.float32)
+ return
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
+
+ params_to_keep_float32 = [
+ "vision_tower.vision_model.embeddings.patch_embedding.weight",
+ "vision_tower.vision_model.embeddings.patch_embedding.bias",
+ "vision_tower.vision_model.embeddings.position_embedding.weight",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "model.norm",
+ ]
+
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_keep_float32):
+ param.data = param.data.to(dtype=torch.float32)
+
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.model.get_image_features(image)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.embed_tokens(tokens)
+
+ def forward(
+ self,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: list[torch.FloatTensor] | None = None,
+ use_cache: bool | None = None,
+ adarms_cond: list[torch.Tensor] | None = None,
+ ):
+ if adarms_cond is None:
+ adarms_cond = [None, None]
+ if inputs_embeds[1] is None:
+ prefix_output = self.paligemma.language_model.forward(
+ inputs_embeds=inputs_embeds[0],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
+ )
+ prefix_past_key_values = prefix_output.past_key_values
+ prefix_output = prefix_output.last_hidden_state
+ suffix_output = None
+ elif inputs_embeds[0] is None:
+ suffix_output = self.gemma_expert.model.forward(
+ inputs_embeds=inputs_embeds[1],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
+ )
+ suffix_output = suffix_output.last_hidden_state
+ prefix_output = None
+ prefix_past_key_values = None
+ else:
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
+
+ # Check if gradient checkpointing is enabled for any of the models
+ use_gradient_checkpointing = (
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
+ and self.gemma_expert.model.gradient_checkpointing
+ and self.training
+ ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
+
+ # Process all layers with gradient checkpointing if enabled
+ for layer_idx in range(num_layers):
+ if use_gradient_checkpointing:
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_layer_complete,
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ paligemma=self.paligemma,
+ gemma_expert=self.gemma_expert,
+ )
+ else:
+ inputs_embeds = compute_layer_complete(
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ paligemma=self.paligemma,
+ gemma_expert=self.gemma_expert,
+ )
+
+ # final norm
+ def compute_final_norms(inputs_embeds, adarms_cond):
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
+ outputs_embeds.append(out_emb)
+ return outputs_embeds
+
+ # Apply gradient checkpointing to final norm if enabled
+ if use_gradient_checkpointing:
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_final_norms,
+ inputs_embeds,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ )
+ else:
+ outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
+
+ prefix_output = outputs_embeds[0]
+ suffix_output = outputs_embeds[1]
+ prefix_past_key_values = None
+
+ return [prefix_output, suffix_output], prefix_past_key_values
+
+
+class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
+ """Core PI05 PyTorch model."""
+
+ def __init__(self, config: PI05Config):
+ super().__init__()
+ self.config = config
+
+ paligemma_config = get_gemma_config(config.paligemma_variant)
+ action_expert_config = get_gemma_config(config.action_expert_variant)
+
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
+ paligemma_config,
+ action_expert_config,
+ use_adarms=[False, True],
+ precision=config.dtype,
+ )
+
+ self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
+ self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
+
+ self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
+ self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
+
+ # Initialize gradient checkpointing flag
+ self.gradient_checkpointing_enabled = False
+
+ # Compile model if requested
+ if config.compile_model:
+ torch.set_float32_matmul_precision("high")
+ self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
+
+ msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
+
+ try:
+ from transformers.models.siglip import check
+
+ if not check.check_whether_transformers_replace_is_installed_correctly():
+ raise ValueError(msg)
+ except ImportError:
+ raise ValueError(msg) from None
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.gradient_checkpointing_enabled = True
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
+ logging.info("Enabled gradient checkpointing for PI05Pytorch model")
+
+ def gradient_checkpointing_disable(self):
+ """Disable gradient checkpointing."""
+ self.gradient_checkpointing_enabled = False
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
+ logging.info("Disabled gradient checkpointing for PI05Pytorch model")
+
+ def _apply_checkpoint(self, func, *args, **kwargs):
+ """Helper method to apply gradient checkpointing if enabled."""
+ if self.gradient_checkpointing_enabled and self.training:
+ return torch.utils.checkpoint.checkpoint(
+ func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
+ )
+ return func(*args, **kwargs)
+
+ def _prepare_attention_masks_4d(self, att_2d_masks):
+ """Helper method to prepare 4D attention masks for transformer."""
+ att_2d_masks_4d = att_2d_masks[:, None, :, :]
+ return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
+
+ def sample_noise(self, shape, device):
+ return torch.normal(
+ mean=0.0,
+ std=1.0,
+ size=shape,
+ dtype=torch.float32,
+ device=device,
+ )
+
+ def sample_time(self, bsize, device):
+ time_beta = sample_beta(
+ self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
+ )
+ time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
+ return time.to(dtype=torch.float32, device=device)
+
+ def embed_prefix(
+ self, images, img_masks, tokens, masks
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Embed images with SigLIP and language tokens with embedding layer."""
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ # Process images
+ for img, img_mask in zip(images, img_masks, strict=True):
+
+ def image_embed_func(img):
+ return self.paligemma_with_expert.embed_image(img)
+
+ img_emb = self._apply_checkpoint(image_embed_func, img)
+ bsize, num_img_embs = img_emb.shape[:2]
+
+ embs.append(img_emb)
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
+ att_masks += [0] * num_img_embs
+
+ # Process language tokens
+ def lang_embed_func(tokens):
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
+ lang_emb_dim = lang_emb.shape[-1]
+ return lang_emb * math.sqrt(lang_emb_dim)
+
+ lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
+ embs.append(lang_emb)
+ pad_masks.append(masks)
+
+ num_lang_embs = lang_emb.shape[1]
+ att_masks += [0] * num_lang_embs
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
+
+ bsize = pad_masks.shape[0]
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks
+
+ def embed_suffix(self, noisy_actions, timestep):
+ """Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ # Embed timestep using sine-cosine positional encoding
+ time_emb = create_sinusoidal_pos_embedding(
+ timestep,
+ self.action_in_proj.out_features,
+ min_period=self.config.min_period,
+ max_period=self.config.max_period,
+ device=timestep.device,
+ )
+ time_emb = time_emb.type(dtype=timestep.dtype)
+
+ # Fuse timestep + action information using an MLP
+ def action_proj_func(noisy_actions):
+ return self.action_in_proj(noisy_actions)
+
+ action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
+
+ def time_mlp_func(time_emb):
+ x = self.time_mlp_in(time_emb)
+ x = F.silu(x)
+ x = self.time_mlp_out(x)
+ return F.silu(x)
+
+ time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
+ action_time_emb = action_emb
+ adarms_cond = time_emb
+
+ embs.append(action_time_emb)
+ bsize, action_time_dim = action_time_emb.shape[:2]
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
+ pad_masks.append(action_time_mask)
+
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
+ att_masks += [1] + ([0] * (self.config.chunk_size - 1))
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks, adarms_cond
+
+ def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
+ """Do a full training forward pass and compute the loss."""
+ if noise is None:
+ noise = self.sample_noise(actions.shape, actions.device)
+
+ if time is None:
+ time = self.sample_time(actions.shape[0], actions.device)
+
+ time_expanded = time[:, None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
+
+ if (
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
+ == torch.bfloat16
+ ):
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
+
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
+
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
+
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+
+ def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
+ attention_mask=att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+ return suffix_out
+
+ suffix_out = self._apply_checkpoint(
+ forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
+ )
+
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+
+ def action_out_proj_func(suffix_out):
+ return self.action_out_proj(suffix_out)
+
+ v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
+
+ return F.mse_loss(u_t, v_t, reduction="none")
+
+ @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
+ def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
+ """Do a full inference forward and compute the action."""
+ if num_steps is None:
+ num_steps = self.config.num_inference_steps
+
+ bsize = tokens.shape[0]
+ device = tokens.device
+
+ if noise is None:
+ # Sample noise with padded dimension as expected by action_in_proj
+ actions_shape = (
+ bsize,
+ self.config.chunk_size,
+ self.config.max_action_dim,
+ ) # Use config max_action_dim for internal processing
+ noise = self.sample_noise(actions_shape, device)
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
+
+ _, past_key_values = self.paligemma_with_expert.forward(
+ attention_mask=prefix_att_2d_masks_4d,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, None],
+ use_cache=True,
+ )
+
+ dt = -1.0 / num_steps
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
+
+ x_t = noise
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
+ while time >= -dt / 2:
+ expanded_time = time.expand(bsize)
+ v_t = self.denoise_step(
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ expanded_time,
+ )
+ x_t = x_t + dt * v_t
+ time += dt
+
+ return x_t
+
+ def denoise_step(
+ self,
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ timestep,
+ ):
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep)
+
+ suffix_len = suffix_pad_masks.shape[1]
+ batch_size = prefix_pad_masks.shape[0]
+ prefix_len = prefix_pad_masks.shape[1]
+
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
+
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+
+ full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
+
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
+ attention_mask=full_att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[None, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+
+ suffix_out = outputs_embeds[1]
+ suffix_out = suffix_out[:, -self.config.chunk_size :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+ return self.action_out_proj(suffix_out)
+
+
+class PI05Policy(PreTrainedPolicy):
+ """PI05 Policy for LeRobot."""
+
+ config_class = PI05Config
+ name = "pi05"
+
+ def __init__(
+ self,
+ config: PI05Config,
+ ):
+ """
+ Args:
+ config: Policy configuration class instance.
+ """
+ super().__init__(config)
+ config.validate_features()
+ self.config = config
+
+ # Initialize the core PI05 model
+ self.model = PI05Pytorch(config)
+
+ # Enable gradient checkpointing if requested
+ if config.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable()
+
+ self.model.to(config.device)
+
+ self.reset()
+
+ @classmethod
+ def from_pretrained(
+ cls: builtins.type[T],
+ pretrained_name_or_path: str | Path,
+ *,
+ config: PreTrainedConfig | None = None,
+ force_download: bool = False,
+ resume_download: bool | None = None,
+ proxies: dict | None = None,
+ token: str | bool | None = None,
+ cache_dir: str | Path | None = None,
+ local_files_only: bool = False,
+ revision: str | None = None,
+ strict: bool = True,
+ **kwargs,
+ ) -> T:
+ """Override the from_pretrained method to handle key remapping and display important disclaimer."""
+ print(
+ "The PI05 model is a direct port of the OpenPI implementation. \n"
+ "This implementation follows the original OpenPI structure for compatibility. \n"
+ "Original implementation: https://github.com/Physical-Intelligence/openpi"
+ )
+ if pretrained_name_or_path is None:
+ raise ValueError("pretrained_name_or_path is required")
+
+ # Use provided config if available, otherwise create default config
+ if config is None:
+ config = PreTrainedConfig.from_pretrained(
+ pretrained_name_or_path=pretrained_name_or_path,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ **kwargs,
+ )
+
+ # Initialize model without loading weights
+ # Check if dataset_stats were provided in kwargs
+ model = cls(config, **kwargs)
+
+ # Now manually load and remap the state dict
+ try:
+ # Try to load the pytorch_model.bin or model.safetensors file
+ print(f"Loading model from: {pretrained_name_or_path}")
+ try:
+ from transformers.utils import cached_file
+
+ # Try safetensors first
+ resolved_file = cached_file(
+ pretrained_name_or_path,
+ "model.safetensors",
+ cache_dir=kwargs.get("cache_dir"),
+ force_download=kwargs.get("force_download", False),
+ resume_download=kwargs.get("resume_download"),
+ proxies=kwargs.get("proxies"),
+ use_auth_token=kwargs.get("use_auth_token"),
+ revision=kwargs.get("revision"),
+ local_files_only=kwargs.get("local_files_only", False),
+ )
+ from safetensors.torch import load_file
+
+ original_state_dict = load_file(resolved_file)
+ print("✓ Loaded state dict from model.safetensors")
+ except Exception as e:
+ print(f"Could not load state dict from remote files: {e}")
+ print("Returning model without loading pretrained weights")
+ return model
+
+ # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
+ fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
+
+ # Then add "model." prefix for all keys that don't already have it
+ remapped_state_dict = {}
+ remap_count = 0
+
+ for key, value in fixed_state_dict.items():
+ if not key.startswith("model."):
+ new_key = f"model.{key}"
+ remapped_state_dict[new_key] = value
+ remap_count += 1
+ if remap_count <= 10: # Only print first 10 to avoid spam
+ print(f"Remapped: {key} -> {new_key}")
+ else:
+ remapped_state_dict[key] = value
+
+ if remap_count > 0:
+ print(f"Remapped {remap_count} state dict keys")
+
+ # Load the remapped state dict into the model
+ missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
+
+ if missing_keys:
+ print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
+ if len(missing_keys) <= 5:
+ for key in missing_keys:
+ print(f" - {key}")
+ else:
+ for key in missing_keys[:5]:
+ print(f" - {key}")
+ print(f" ... and {len(missing_keys) - 5} more")
+
+ if unexpected_keys:
+ print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
+ if len(unexpected_keys) <= 5:
+ for key in unexpected_keys:
+ print(f" - {key}")
+ else:
+ for key in unexpected_keys[:5]:
+ print(f" - {key}")
+ print(f" ... and {len(unexpected_keys) - 5} more")
+
+ if not missing_keys and not unexpected_keys:
+ print("All keys loaded successfully!")
+
+ except Exception as e:
+ print(f"Warning: Could not remap state dict keys: {e}")
+
+ return model
+
+ def _fix_pytorch_state_dict_keys(
+ self, state_dict, model_config
+ ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys`
+ """Fix state dict keys to match current model architecture."""
+ import re
+
+ fixed_state_dict = {}
+
+ for key, value in state_dict.items():
+ new_key = key
+
+ # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias
+ # For gemma expert layers
+ if re.match(
+ r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
+ key,
+ ):
+ # Check if the model actually has adaRMS enabled for the expert
+ expert_uses_adarms = getattr(
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
+ )
+ if expert_uses_adarms:
+ logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
+ continue
+
+ if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
+ # Check if the model actually has adaRMS enabled for the expert
+ expert_uses_adarms = getattr(
+ self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
+ )
+ if expert_uses_adarms:
+ logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
+ continue
+
+ # Handle MLP naming changes for pi05
+ # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_*
+ if key.startswith("action_time_mlp_in."):
+ new_key = key.replace("action_time_mlp_in.", "time_mlp_in.")
+ elif key.startswith("action_time_mlp_out."):
+ new_key = key.replace("action_time_mlp_out.", "time_mlp_out.")
+ # Also handle state_proj which shouldn't exist in pi05
+ if key.startswith("state_proj."):
+ logging.warning(f"Skipping state_proj key in pi05 mode: {key}")
+ continue
+
+ # Handle vision tower embedding layer potential differences
+ if "patch_embedding" in key:
+ # Some checkpoints might have this, but current model expects different structure
+ logging.warning(f"Vision embedding key might need handling: {key}")
+
+ fixed_state_dict[new_key] = value
+
+ return fixed_state_dict
+
+ def get_optim_params(self) -> dict:
+ return self.parameters()
+
+ def reset(self):
+ """Reset internal state - called when environment resets."""
+ self._action_queue = deque(maxlen=self.config.n_action_steps)
+ self._queues = {
+ ACTION: deque(maxlen=self.config.n_action_steps),
+ }
+
+ def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
+ """Preprocess images for the model.
+
+ Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1].
+ PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1].
+ """
+ images = []
+ img_masks = []
+
+ # Get device from model parameters
+ device = next(self.parameters()).device
+
+ present_img_keys = [key for key in self.config.image_features if key in batch]
+ missing_img_keys = [key for key in self.config.image_features if key not in batch]
+
+ if len(present_img_keys) == 0:
+ raise ValueError(
+ f"All image features are missing from the batch. At least one expected. "
+ f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
+ )
+
+ # Preprocess image features present in the batch
+ for key in present_img_keys:
+ img = batch[key]
+
+ # Ensure tensor is on the same device as the model
+ if img.device != device:
+ img = img.to(device)
+
+ # Ensure float32 dtype for consistency
+ if img.dtype != torch.float32:
+ img = img.to(torch.float32)
+
+ # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats
+ is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1
+
+ if is_channels_first:
+ # Convert [B, C, H, W] to [B, H, W, C] for processing
+ img = img.permute(0, 2, 3, 1)
+
+ # from openpi preprocess_observation_pytorch: Resize with padding if needed
+ if img.shape[1:3] != self.config.image_resolution:
+ img = resize_with_pad_torch(img, *self.config.image_resolution)
+
+ # Normalize from [0,1] to [-1,1] as expected by siglip
+ img = img * 2.0 - 1.0
+
+ # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
+ if is_channels_first:
+ img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
+
+ images.append(img)
+ # Create mask (all ones for real images)
+ bsize = img.shape[0]
+ mask = torch.ones(bsize, dtype=torch.bool, device=device)
+ img_masks.append(mask)
+
+ # Create image features not present in the batch as fully 0 padded images
+ for _num_empty_cameras in range(len(missing_img_keys)):
+ img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP
+ mask = torch.zeros_like(mask) # Mask is zero for empty cameras
+ images.append(img)
+ img_masks.append(mask)
+
+ return images, img_masks
+
+ def prepare_action(self, batch):
+ """Pad action"""
+ actions = pad_vector(batch[ACTION], self.config.max_action_dim)
+ return actions
+
+ @torch.no_grad()
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ """Select a single action given environment observations."""
+ self.eval()
+
+ # Action queue logic for n_action_steps > 1
+ if len(self._action_queue) == 0:
+ actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
+ # Transpose to get shape (n_action_steps, batch_size, action_dim)
+ self._action_queue.extend(actions.transpose(0, 1))
+
+ return self._action_queue.popleft()
+
+ @torch.no_grad()
+ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ """Predict a chunk of actions given environment observations."""
+ self.eval()
+
+ # Prepare inputs
+ images, img_masks = self._preprocess_images(batch)
+ tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+
+ # Sample actions using the model (no separate state needed for PI05)
+ actions = self.model.sample_actions(images, img_masks, tokens, masks)
+
+ # Unpad actions to actual action dimension
+ original_action_dim = self.config.output_features[ACTION].shape[0]
+ actions = actions[:, :, :original_action_dim]
+
+ return actions
+
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
+ """Run the batch through the model and compute the loss for training."""
+
+ # Prepare inputs
+ images, img_masks = self._preprocess_images(batch)
+ tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
+
+ actions = self.prepare_action(batch)
+
+ # Compute loss (no separate state needed for PI05)
+ losses = self.model.forward(images, img_masks, tokens, masks, actions)
+
+ # Truncate losses to actual action dimensions
+ original_action_dim = self.config.output_features[ACTION].shape[0]
+ losses = losses[:, :, :original_action_dim]
+
+ loss = losses.mean()
+
+ loss_dict = {
+ "loss": loss.item(),
+ "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
+ }
+
+ return loss, loss_dict
diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py
new file mode 100644
index 000000000..e29bc4c23
--- /dev/null
+++ b/src/lerobot/policies/pi05/processor_pi05.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.policies.pi05.configuration_pi05 import PI05Config
+from lerobot.policies.pi05.modeling_pi05 import pad_vector
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ RenameObservationsProcessorStep,
+ TokenizerProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.processor.core import EnvTransition, TransitionKey
+from lerobot.utils.constants import (
+ OBS_STATE,
+ POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ POLICY_PREPROCESSOR_DEFAULT_NAME,
+)
+
+
+@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
+@dataclass
+class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
+ """
+ Processor step to prepare the state and tokenize the language input.
+ """
+
+ max_state_dim: int = 32
+ task_key: str = "task"
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ transition = transition.copy()
+
+ state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
+ if state is None:
+ raise ValueError("State is required for PI05")
+ tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
+ if tasks is None:
+ raise ValueError("No task found in complementary data")
+
+ # TODO: check if this necessary
+ state = deepcopy(state)
+
+ # Prepare state (pad to max_state_dim)
+ state = pad_vector(state, self.max_state_dim)
+
+ # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
+ # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
+ state_np = state.cpu().numpy()
+ discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+ full_prompts = []
+ for i, task in enumerate(tasks):
+ cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
+ state_str = " ".join(map(str, discretized_states[i]))
+ full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
+ full_prompts.append(full_prompt)
+
+ transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
+ # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
+ # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
+ return transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ This step does not alter the feature definitions.
+ """
+ return features
+
+
+def make_pi05_pre_post_processors(
+ config: PI05Config,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the PI0 policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Appending a newline character to the task description for tokenizer compatibility.
+ 5. Tokenizing the text prompt using the PaliGemma tokenizer.
+ 6. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the PI0 policy.
+ dataset_stats: A dictionary of statistics for normalization.
+ preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
+ postprocessor_kwargs: Additional arguments for the post-processor pipeline.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ # Add remaining processors
+ input_steps: list[ProcessorStep] = [
+ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
+ AddBatchDimensionProcessorStep(),
+ # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
+ # because the tokenizer step expects normalized state in [-1, 1] range for discretization
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
+ TokenizerProcessorStep(
+ tokenizer_name="google/paligemma-3b-pt-224",
+ max_length=config.tokenizer_max_length,
+ padding_side="right",
+ padding="max_length",
+ ),
+ DeviceProcessorStep(device=config.device),
+ ]
+
+ output_steps: list[ProcessorStep] = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py
index b72bcd735..cefd4e688 100644
--- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py
+++ b/src/lerobot/policies/pi0fast/configuration_pi0fast.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
@@ -6,6 +22,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
+from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0fast")
@@ -99,7 +116,7 @@ class PI0FASTConfig(PreTrainedConfig):
def validate_features(self) -> None:
for i in range(self.empty_cameras):
- key = f"observation.images.empty_camera_{i}"
+ key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py
index d3e576d1c..102cfb8fa 100644
--- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py
+++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py
@@ -21,17 +21,18 @@
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
+Disclaimer: It is not expected to perform as well as the original implementation.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--policy.path=lerobot/pi0fast_base \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of training the pi0+FAST neural network with from scratch:
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--policy.type=pi0fast \
--dataset.repo_id=danaaubakirova/koch_test
```
@@ -56,10 +57,9 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
-from lerobot.constants import ACTION, OBS_STATE
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.utils.constants import ACTION, OBS_STATE
PRECISION = {
"float16": torch.float16,
@@ -145,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FAST(config)
@@ -162,6 +154,16 @@ class PI0FASTPolicy(PreTrainedPolicy):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ """Override the from_pretrained method to display important disclaimer."""
+ print(
+ "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
+ " It is not expected to perform as well as the original implementation. \n"
+ " Original implementation: https://github.com/Physical-Intelligence/openpi"
+ )
+ return super().from_pretrained(*args, **kwargs)
+
def get_optim_params(self) -> dict:
return self.parameters()
@@ -192,12 +194,12 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
- @torch.no_grad
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
- @torch.no_grad
+ @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -210,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
- batch = self.normalize_inputs(batch)
-
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
@@ -224,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
] # self.config.max_action_dim # self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
- actions = self.unnormalize_outputs({"action": actions})["action"]
-
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -238,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
- batch = self.normalize_inputs(batch)
- batch = self.normalize_targets(batch)
loss_dict = self.model.forward(batch)
return loss_dict["loss"], loss_dict
@@ -477,6 +473,8 @@ class PI0FAST(nn.Module):
param.data = param.data.to(dtype=torch_precision)
self.set_requires_grad()
self.image_keys = self.config.image_features.keys()
+ # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed
+ # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index'
self.ignore_index = self.pi0_paligemma.config.ignore_index
self.padding_side = self.config.padding_side
diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py
new file mode 100644
index 000000000..95b5e541b
--- /dev/null
+++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+
+from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_pi0fast_pre_post_processors(
+ config: PI0FASTConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the PI0Fast policy.
+ dataset_stats: A dictionary of statistics for normalization.
+ preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
+ postprocessor_kwargs: Additional arguments for the post-processor pipeline.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py
index d18b798a8..3f5d89ec5 100644
--- a/src/lerobot/policies/pretrained.py
+++ b/src/lerobot/policies/pretrained.py
@@ -12,29 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import builtins
import logging
import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import List, Type, TypeVar
+from typing import TypedDict, TypeVar
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
-from safetensors.torch import load_model as load_model_as_safetensor
-from safetensors.torch import save_model as save_model_as_safetensor
+from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
+from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
+from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="PreTrainedPolicy")
+class ActionSelectKwargs(TypedDict, total=False):
+ noise: Tensor | None
+
+
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Base class for policy models.
@@ -67,7 +73,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
@classmethod
def from_pretrained(
- cls: Type[T],
+ cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
@@ -128,18 +134,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
- if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
- load_model_as_safetensor(model, model_file, strict=strict)
- if map_location != "cpu":
- logging.warning(
- "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
- " This means that the model is loaded on 'cpu' first and then copied to the device."
- " This leads to a slower loading time."
- " Please update safetensors to version 0.4.3 or above for improved performance."
- )
- model.to(map_location)
- else:
- safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+ # Create base kwargs
+ kwargs = {"strict": strict}
+
+ # Add device parameter for newer versions that support it
+ if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
+ kwargs["device"] = map_location
+
+ # Load the model with appropriate kwargs
+ missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
+ log_model_loading_keys(missing_keys, unexpected_keys)
+
+ # For older versions, manually move to device if needed
+ if "device" not in kwargs and map_location != "cpu":
+ logging.warning(
+ "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
+ " This means that the model is loaded on 'cpu' first and then copied to the device."
+ " This leads to a slower loading time."
+ " Please update safetensors to version 0.4.3 or above for improved performance."
+ )
+ model.to(map_location)
return model
@abc.abstractmethod
@@ -172,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError
@abc.abstractmethod
- def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk
@@ -181,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError
@abc.abstractmethod
- def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
@@ -223,7 +237,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
- self, dataset_repo_id: str, model_type: str, license: str | None, tags: List[str] | None
+ self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard:
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
@@ -237,7 +251,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
base_model=base_model,
)
- template_card = files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text()
+ template_card = (
+ files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
+ )
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card
diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py
index c57eeeb72..ada12330c 100644
--- a/src/lerobot/policies/sac/configuration_sac.py
+++ b/src/lerobot/policies/sac/configuration_sac.py
@@ -19,8 +19,8 @@ from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
-from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.optim.optimizers import MultiAdamConfig
+from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def is_image_feature(key: str) -> bool:
@@ -139,8 +139,6 @@ class SACConfig(PreTrainedConfig):
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
- # Seed for the online environment
- online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
@@ -225,7 +223,7 @@ class SACConfig(PreTrainedConfig):
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
- if "action" not in self.output_features:
+ if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py
index 54ea122a8..c7c6798ed 100644
--- a/src/lerobot/policies/sac/modeling_sac.py
+++ b/src/lerobot/policies/sac/modeling_sac.py
@@ -16,8 +16,9 @@
# limitations under the License.
import math
+from collections.abc import Callable
from dataclasses import asdict
-from typing import Callable, Literal
+from typing import Literal
import einops
import numpy as np
@@ -27,10 +28,10 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
-from lerobot.policies.normalize import NormalizeBuffer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
@@ -44,15 +45,13 @@ class SACPolicy(
def __init__(
self,
config: SACConfig | None = None,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__(config)
config.validate_features()
self.config = config
# Determine action dimension and initialize all components
- continuous_action_dim = config.output_features["action"].shape[0]
- self._init_normalization(dataset_stats)
+ continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -76,7 +75,7 @@ class SACPolicy(
"""Reset the policy"""
pass
- @torch.no_grad
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
@@ -87,8 +86,7 @@ class SACPolicy(
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
- # Cache and normalize image features
- observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
+ observations_features = self.actor.encoder.get_cached_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
@@ -160,7 +158,7 @@ class SACPolicy(
The computed loss tensor
"""
# Extract common components from batch
- actions: Tensor = batch["action"]
+ actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
@@ -390,28 +388,12 @@ class SACPolicy(
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
- def _init_normalization(self, dataset_stats):
- """Initialize input/output normalization modules."""
- self.normalize_inputs = nn.Identity()
- self.normalize_targets = nn.Identity()
- if self.config.dataset_stats is not None:
- params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
- self.normalize_inputs = NormalizeBuffer(
- self.config.input_features, self.config.normalization_mapping, params
- )
- stats = dataset_stats or params
- self.normalize_targets = NormalizeBuffer(
- self.config.output_features, self.config.normalization_mapping, stats
- )
-
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
- self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
+ self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_actor = (
- self.encoder_critic
- if self.shared_encoder
- else SACObservationEncoder(self.config, self.normalize_inputs)
+ self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
@@ -423,9 +405,7 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
- self.critic_ensemble = CriticEnsemble(
- encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
- )
+ self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
@@ -433,9 +413,7 @@ class SACPolicy(
)
for _ in range(self.config.num_critics)
]
- self.critic_target = CriticEnsemble(
- encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
- )
+ self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
@@ -489,10 +467,9 @@ class SACPolicy(
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
- def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
+ def __init__(self, config: SACConfig) -> None:
super().__init__()
self.config = config
- self.input_normalization = input_normalizer
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
@@ -537,17 +514,17 @@ class SACObservationEncoder(nn.Module):
)
def _init_state_layers(self) -> None:
- self.has_env = "observation.environment_state" in self.config.input_features
- self.has_state = "observation.state" in self.config.input_features
+ self.has_env = OBS_ENV_STATE in self.config.input_features
+ self.has_state = OBS_STATE in self.config.input_features
if self.has_env:
- dim = self.config.input_features["observation.environment_state"].shape[0]
+ dim = self.config.input_features[OBS_ENV_STATE].shape[0]
self.env_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
if self.has_state:
- dim = self.config.input_features["observation.state"].shape[0]
+ dim = self.config.input_features[OBS_STATE].shape[0]
self.state_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
@@ -567,16 +544,15 @@ class SACObservationEncoder(nn.Module):
def forward(
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
) -> Tensor:
- obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
- cache = self.get_cached_image_features(obs, normalize=False)
+ cache = self.get_cached_image_features(obs)
parts.append(self._encode_images(cache, detach))
if self.has_env:
- parts.append(self.env_encoder(obs["observation.environment_state"]))
+ parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
if self.has_state:
- parts.append(self.state_encoder(obs["observation.state"]))
+ parts.append(self.state_encoder(obs[OBS_STATE]))
if parts:
return torch.cat(parts, dim=-1)
@@ -584,7 +560,7 @@ class SACObservationEncoder(nn.Module):
"No parts to concatenate, you should have at least one image or environment state or state"
)
- def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
+ def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
@@ -596,26 +572,17 @@ class SACObservationEncoder(nn.Module):
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
- Normalization behavior:
- - When called from inside forward(): set normalize=False since inputs are already normalized
- - When called from outside forward(): set normalize=True to ensure proper input normalization
-
Usage patterns:
- - Called in select_action() with normalize=True
+ - Called in select_action()
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
- - Called internally by forward() with normalize=False
+ - Called internally by forward()
Args:
obs: Dictionary of observation tensors containing image keys
- normalize: Whether to normalize observations before encoding
- Set to True when calling directly from outside the encoder's forward method
- Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
- if normalize:
- obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
@@ -746,7 +713,6 @@ class CriticEnsemble(nn.Module):
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
- output_normalization (nn.Module): normalization layer for actions.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
@@ -756,13 +722,11 @@ class CriticEnsemble(nn.Module):
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
- output_normalization: nn.Module,
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
- self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
def forward(
@@ -774,11 +738,6 @@ class CriticEnsemble(nn.Module):
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
- # NOTE: We normalize actions it helps for sample efficiency
- actions: dict[str, torch.tensor] = {"action": actions}
- # NOTE: Normalization layer took dict in input and outputs a dict that why
- actions = self.output_normalization(actions)["action"]
- actions = actions.to(device)
obs_enc = self.encoder(observations, cache=observation_features)
@@ -1102,15 +1061,3 @@ class TanhMultivariateNormalDiag(TransformedDistribution):
x = transform(x)
return x
-
-
-def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
- converted_params = {}
- for outer_key, inner_dict in normalization_params.items():
- converted_params[outer_key] = {}
- for key, value in inner_dict.items():
- converted_params[outer_key][key] = torch.tensor(value)
- if "image" in outer_key:
- converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
-
- return converted_params
diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py
new file mode 100644
index 000000000..cf90e3cb4
--- /dev/null
+++ b/src/lerobot/policies/sac/processor_sac.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+
+from lerobot.policies.sac.configuration_sac import SACConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_sac_pre_post_processors(
+ config: SACConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the SAC policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the SAC policy.
+ dataset_stats: A dictionary of statistics for normalization.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ # Add remaining processors
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py
index fc53283b3..9b76b8037 100644
--- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py
+++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py
@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.constants import OBS_IMAGE
@PreTrainedConfig.register_subclass(name="reward_classifier")
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate feature configurations."""
- has_image = any(key.startswith("observation.image") for key in self.input_features)
+ has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
if not has_image:
raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features"
diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py
index cadd1c9f2..dba6a174b 100644
--- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py
+++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py
@@ -19,10 +19,9 @@ import logging
import torch
from torch import Tensor, nn
-from lerobot.constants import OBS_IMAGE, REWARD
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.utils.constants import OBS_IMAGE, REWARD
class ClassifierOutput:
@@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy):
def __init__(
self,
config: RewardClassifierConfig,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
- # Initialize normalization (standardized with the policy framework)
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
@@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
- # Normalize inputs if needed
- batch = self.normalize_inputs(batch)
- batch = self.normalize_targets(batch)
-
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py
new file mode 100644
index 000000000..c2a34eab2
--- /dev/null
+++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py
@@ -0,0 +1,82 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+
+from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.processor import (
+ DeviceProcessorStep,
+ IdentityProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+
+
+def make_classifier_processor(
+ config: RewardClassifierConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the reward classifier.
+
+ The pre-processing pipeline prepares input data for the classifier by:
+ 1. Normalizing both input and output features based on dataset statistics.
+ 2. Moving the data to the specified device.
+
+ The post-processing pipeline handles the classifier's output by:
+ 1. Moving the data to the CPU.
+ 2. Applying an identity step, as no unnormalization is needed for the output logits.
+
+ Args:
+ config: The configuration object for the RewardClassifier.
+ dataset_stats: A dictionary of statistics for normalization.
+ preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
+ postprocessor_kwargs: Additional arguments for the post-processor pipeline.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ NormalizerProcessorStep(
+ features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ NormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device=config.device),
+ ]
+ output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
+
+ return (
+ PolicyProcessorPipeline(
+ steps=input_steps,
+ name="classifier_preprocessor",
+ ),
+ PolicyProcessorPipeline(
+ steps=output_steps,
+ name="classifier_postprocessor",
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/smolvla/README.md b/src/lerobot/policies/smolvla/README.md
new file mode 120000
index 000000000..f8de40269
--- /dev/null
+++ b/src/lerobot/policies/smolvla/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_smolvla_README.md
\ No newline at end of file
diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py
index 571900c4a..eedf477a5 100644
--- a/src/lerobot/policies/smolvla/configuration_smolvla.py
+++ b/src/lerobot/policies/smolvla/configuration_smolvla.py
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
+from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("smolvla")
@@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig):
def validate_features(self) -> None:
for i in range(self.empty_cameras):
- key = f"observation.images.empty_camera_{i}"
+ key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py
index 11bb8bf52..23fc3ca4f 100644
--- a/src/lerobot/policies/smolvla/modeling_smolvla.py
+++ b/src/lerobot/policies/smolvla/modeling_smolvla.py
@@ -28,7 +28,7 @@ pip install -e ".[smolvla]"
Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
@@ -38,7 +38,7 @@ python -m lerobot.scripts.train \
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
@@ -53,125 +53,21 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""
import math
-import os
-import re
from collections import deque
-import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
-from transformers import AutoProcessor
-from lerobot.constants import ACTION, OBS_STATE
-from lerobot.policies.normalize import (
- Normalize,
- Unnormalize,
-)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
populate_queues,
)
+from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.utils import get_safe_dtype
-# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
-_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
-
-
-def canonicalise(k: str) -> str:
- """
- Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
- normalisation-buffer key.
- """
- return _VARIANT_RE.sub(".buffer_", k)
-
-
-def standardise_state_dict(
- checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
-) -> tuple[dict[str, torch.Tensor], list[str]]:
- """
- • Re-keys `checkpoint ` so that every entry matches the *reference* key set.
- • If several variant keys collapse to the same canonical name we keep the
- first one and log the collision.
- • Returns the new dict + a list of entries that could not be matched.
- """
- out, collisions, unmatched = {}, {}, []
-
- for k, v in checkpoint.items():
- canon = canonicalise(k)
- if canon in ref_keys:
- if canon in out: # duplicate after collapsing
- collisions.setdefault(canon, []).append(k)
- else:
- out[canon] = v
- else:
- unmatched.append(k)
-
- if verbose:
- for canon, variants in collisions.items():
- print(f"[standardise_state_dict] '{canon}' ← {variants}")
- if unmatched:
- print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
-
- out.update({k: checkpoint[k] for k in unmatched})
- return out, unmatched
-
-
-def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
- """
- Renames keys in a checkpoint dictionary based on the given rename string.
-
- Args:
- checkpoint (dict): The checkpoint dictionary.
- rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
-
- Returns:
- dict: The modified checkpoint with renamed keys.
- """
-
- rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
-
- new_checkpoint = {}
- for k, v in checkpoint.items():
- for old_key, new_key in rename_dict.items():
- if old_key in k:
- k = k.replace(old_key, new_key)
- new_checkpoint[k] = v
- return new_checkpoint
-
-
-def load_smolvla(
- model: torch.nn.Module,
- filename: str | os.PathLike,
- *,
- device: str = "cpu",
- checkpoint_keys_mapping: str = "",
-) -> torch.nn.Module:
- state_dict = safetensors.torch.load_file(filename, device=device)
-
- # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
- if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
- state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
-
- state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
-
- # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
- norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
- state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
-
- missing, unexpected = model.load_state_dict(state_dict, strict=False)
-
- if not all(key.startswith(norm_keys) for key in missing) or unexpected:
- raise RuntimeError(
- "SmolVLA %d missing / %d unexpected keys",
- len(missing),
- len(unexpected),
- )
-
- return model
-
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -194,12 +90,6 @@ def create_sinusoidal_pos_embedding(
return pos_emb
-def sample_beta(alpha, beta, bsize, device):
- gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
- gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
- return gamma1 / (gamma1 + gamma2)
-
-
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
@@ -332,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
- dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
- that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
@@ -363,34 +242,23 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
- # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
- @classmethod
- def _load_as_safetensor(
- cls,
- model: "SmolVLAPolicy",
- model_file: str,
- map_location: str,
- strict: bool,
- ):
- safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
- return load_smolvla(
- model,
- model_file,
- device=map_location,
- checkpoint_keys_mapping="model._orig_mod.//model.",
- )
-
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+ # TODO: Check if this for loop is needed.
+ # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
+ # In the case of offline inference, we have the action in the batch
+ # that why without the k != ACTION check, it will raise an error because we are trying to stack
+ # on an empty container.
for k in batch:
- if k in self._queues:
+ if k in self._queues and k != ACTION:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
- lang_tokens, lang_masks = self.prepare_language(batch)
+ lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
+ lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
@@ -398,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
- actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
-
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
@@ -409,10 +275,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
- batch = self.normalize_inputs(batch)
-
return batch
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()
@@ -422,7 +287,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions = self._get_action_chunk(batch, noise)
return actions
- @torch.no_grad
+ @torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
@@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
- batch = self.normalize_inputs(batch)
- batch = self.normalize_targets(batch)
+
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
- lang_tokens, lang_masks = self.prepare_language(batch)
+ lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
+ lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
@@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
img_masks.append(mask)
return images, img_masks
- def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
- """Tokenize the text input"""
- device = batch[OBS_STATE].device
- tasks = batch["task"]
- if isinstance(tasks, str):
- tasks = [tasks]
-
- if len(tasks) == 1:
- tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
-
- tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
-
- tokenized_prompt = self.language_tokenizer.__call__(
- tasks,
- padding=self.config.pad_language_to,
- padding_side="right",
- max_length=self.config.tokenizer_max_length,
- return_tensors="pt",
- )
- lang_tokens = tokenized_prompt["input_ids"].to(device=device)
- lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
-
- return lang_tokens, lang_masks
-
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -630,7 +471,7 @@ class VLAFlowMatching(nn.Module):
└──────────────────────────────┘
"""
- def __init__(self, config):
+ def __init__(self, config: SmolVLAConfig):
super().__init__()
self.config = config
@@ -684,9 +525,10 @@ class VLAFlowMatching(nn.Module):
return noise
def sample_time(self, bsize, device):
- time_beta = sample_beta(1.5, 1.0, bsize, device)
+ beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
+ time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
- return time.to(dtype=torch.float32, device=device)
+ return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py
new file mode 100644
index 000000000..3fc130aa1
--- /dev/null
+++ b/src/lerobot/policies/smolvla/processor_smolvla.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ ComplementaryDataProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ ProcessorStepRegistry,
+ RenameObservationsProcessorStep,
+ TokenizerProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_smolvla_pre_post_processors(
+ config: SmolVLAConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the SmolVLA policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Ensuring the language task description ends with a newline character.
+ 5. Tokenizing the language task description.
+ 6. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output actions to their original scale.
+
+ Args:
+ config: The configuration object for the SmolVLA policy.
+ dataset_stats: A dictionary of statistics for normalization.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
+ AddBatchDimensionProcessorStep(),
+ SmolVLANewLineProcessor(),
+ TokenizerProcessorStep(
+ tokenizer_name=config.vlm_model_name,
+ padding=config.pad_language_to,
+ padding_side="right",
+ max_length=config.tokenizer_max_length,
+ ),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
+
+
+@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
+class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
+ """
+ A processor step that ensures the 'task' description ends with a newline character.
+
+ This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
+ newline at the end of the prompt. It handles both single string tasks and lists
+ of string tasks.
+ """
+
+ def complementary_data(self, complementary_data):
+ if "task" not in complementary_data:
+ return complementary_data
+
+ task = complementary_data["task"]
+ if task is None:
+ return complementary_data
+
+ new_complementary_data = dict(complementary_data)
+
+ # Handle both string and list of strings
+ if isinstance(task, str):
+ # Single string: add newline if not present
+ if not task.endswith("\n"):
+ new_complementary_data["task"] = f"{task}\n"
+ elif isinstance(task, list) and all(isinstance(t, str) for t in task):
+ # List of strings: add newline to each if not present
+ new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
+ # If task is neither string nor list of strings, leave unchanged
+
+ return new_complementary_data
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py
index 07eae8089..f3d1a693a 100644
--- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py
+++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py
@@ -13,7 +13,6 @@
# limitations under the License.
import copy
-from typing import List, Optional
import torch
from torch import nn
@@ -403,12 +402,12 @@ class SmolVLMWithExpertModel(nn.Module):
def forward(
self,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: List[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- fill_kv_cache: Optional[bool] = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: list[torch.FloatTensor] = None,
+ use_cache: bool | None = None,
+ fill_kv_cache: bool | None = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
diff --git a/src/lerobot/policies/tdmpc/README.md b/src/lerobot/policies/tdmpc/README.md
new file mode 120000
index 000000000..413ea87b8
--- /dev/null
+++ b/src/lerobot/policies/tdmpc/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_tdmpc_README.md
\ No newline at end of file
diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
index 8b70b265d..195cf6154 100644
--- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
@@ -24,9 +24,9 @@ The comments in this code may sometimes refer to these references:
# ruff: noqa: N806
from collections import deque
+from collections.abc import Callable
from copy import deepcopy
from functools import partial
-from typing import Callable
import einops
import numpy as np
@@ -35,11 +35,10 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
-from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
class TDMPCPolicy(PreTrainedPolicy):
@@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
- def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
+ def __init__(
+ self,
+ config: TDMPCConfig,
+ ):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
- dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
- that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
@@ -99,18 +91,18 @@ class TDMPCPolicy(PreTrainedPolicy):
called on `env.reset()`
"""
self._queues = {
- "observation.state": deque(maxlen=1),
- "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
+ OBS_STATE: deque(maxlen=1),
+ ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self.config.image_features:
- self._queues["observation.image"] = deque(maxlen=1)
+ self._queues[OBS_IMAGE] = deque(maxlen=1)
if self.config.env_state_feature:
- self._queues["observation.environment_state"] = deque(maxlen=1)
+ self._queues[OBS_ENV_STATE] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
- @torch.no_grad
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
@@ -137,16 +129,21 @@ class TDMPCPolicy(PreTrainedPolicy):
actions = torch.clamp(actions, -1, +1)
- actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
- batch = self.normalize_inputs(batch)
+ # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
+ if ACTION in batch:
+ batch.pop(ACTION)
+
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
+ # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
+ if ACTION in batch:
+ batch.pop(ACTION)
self._queues = populate_queues(self._queues, batch)
@@ -315,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
device = get_device_from_parameters(self)
- batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
- batch = self.normalize_targets(batch)
info = {}
@@ -330,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy):
action = batch[ACTION] # (t, b, action_dim)
reward = batch[REWARD] # (t, b)
- observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
+ observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
@@ -392,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
- * ~batch["observation.state_is_pad"][0]
+ * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
- * ~batch["observation.state_is_pad"][1:]
+ * ~batch[f"{OBS_STR}.state_is_pad"][1:]
)
.sum(0)
.mean()
@@ -408,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy):
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
- * ~batch["observation.state_is_pad"][0]
+ * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
@@ -424,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy):
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
- * ~batch["observation.state_is_pad"][0]
+ * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
- * ~batch["observation.state_is_pad"][1:]
+ * ~batch[f"{OBS_STR}.state_is_pad"][1:]
)
.sum(0)
.mean()
@@ -446,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
- * ~batch["observation.state_is_pad"][0]
+ * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
@@ -482,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy):
* mse
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
- * ~batch["observation.state_is_pad"][0]
+ * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py
new file mode 100644
index 000000000..9b6f97e50
--- /dev/null
+++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_tdmpc_pre_post_processors(
+ config: TDMPCConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the TDMPC policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the TDMPC policy.
+ dataset_stats: A dictionary of statistics for normalization.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index 5659e8727..5a3994cdf 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from collections import deque
import torch
@@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
with torch.inference_mode():
output = module(dummy_input)
return tuple(output.shape)
+
+
+def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
+ """Log missing and unexpected keys when loading a model.
+
+ Args:
+ missing_keys (list[str]): Keys that were expected but not found.
+ unexpected_keys (list[str]): Keys that were found but not expected.
+ """
+ if missing_keys:
+ logging.warning(f"Missing key(s) when loading model: {missing_keys}")
+ if unexpected_keys:
+ logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
diff --git a/src/lerobot/policies/vqbet/README.md b/src/lerobot/policies/vqbet/README.md
new file mode 120000
index 000000000..a4ae9291a
--- /dev/null
+++ b/src/lerobot/policies/vqbet/README.md
@@ -0,0 +1 @@
+../../../../docs/source/policy_vqbet_README.md
\ No newline at end of file
diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py
index d7a79f189..44ada9f17 100644
--- a/src/lerobot/policies/vqbet/configuration_vqbet.py
+++ b/src/lerobot/policies/vqbet/configuration_vqbet.py
@@ -82,7 +82,6 @@ class VQBeTConfig(PreTrainedConfig):
gpt_n_head: Number of headers of GPT
gpt_hidden_dim: Size of hidden dimensions of GPT
dropout: Dropout rate for GPT
- mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
offset_loss_weight: A constant that is multiplied to the offset loss
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
@@ -125,7 +124,6 @@ class VQBeTConfig(PreTrainedConfig):
gpt_n_head: int = 8
gpt_hidden_dim: int = 512
dropout: float = 0.1
- mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.0
primary_code_loss_weight: float = 5.0
secondary_code_loss_weight: float = 0.5
diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py
index c045ccbd2..91d609701 100644
--- a/src/lerobot/policies/vqbet/modeling_vqbet.py
+++ b/src/lerobot/policies/vqbet/modeling_vqbet.py
@@ -18,7 +18,7 @@
import warnings
from collections import deque
-from typing import Callable, List
+from collections.abc import Callable
import einops
import numpy as np
@@ -27,12 +27,11 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
-from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
-from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# ruff: noqa: N806
@@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
- dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
- self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
- self.normalize_targets = Normalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
- self.unnormalize_outputs = Unnormalize(
- config.output_features, config.normalization_mapping, dataset_stats
- )
-
self.vqbet = VQBeTModel(config)
self.reset()
@@ -124,14 +114,13 @@ class VQBeTPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.action_chunk_size),
}
- @torch.no_grad
+ @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
- actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
- @torch.no_grad
+ @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -139,11 +128,16 @@ class VQBeTPolicy(PreTrainedPolicy):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
-
- batch = self.normalize_inputs(batch)
+ # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
+ if ACTION in batch:
+ batch.pop(ACTION)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
- batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
- # Note: It's important that this happens after stacking the images into a single key.
+ # NOTE: It's important that this happens after stacking the images into a single key.
+ batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+ # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
+ if ACTION in batch:
+ batch.pop(ACTION)
+
self._queues = populate_queues(self._queues, batch)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -162,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
- batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
- batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
# loss: total loss of training RVQ
@@ -348,14 +340,12 @@ class VQBeTModel(nn.Module):
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation.
- assert set(batch).issuperset({"observation.state", "observation.images"})
- batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+ assert set(batch).issuperset({OBS_STATE, OBS_IMAGES})
+ batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
- img_features = self.rgb_encoder(
- einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
- )
+ img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
@@ -367,9 +357,7 @@ class VQBeTModel(nn.Module):
img_features
) # (batch, obs_step, number of different cameras, projection dims)
input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
- input_tokens.append(
- self.state_projector(batch["observation.state"])
- ) # (batch, obs_step, projection dims)
+ input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
@@ -901,7 +889,7 @@ class MLP(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
- hidden_channels: List[int],
+ hidden_channels: list[int],
):
layers = []
in_dim = in_channels
diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py
new file mode 100644
index 000000000..1e19ff779
--- /dev/null
+++ b/src/lerobot/policies/vqbet/processor_vqbet.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+
+# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
+# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import torch
+
+from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RenameObservationsProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+def make_vqbet_pre_post_processors(
+ config: VQBeTConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Constructs pre-processor and post-processor pipelines for the VQ-BeT policy.
+
+ The pre-processing pipeline prepares input data for the model by:
+ 1. Renaming features, allowing customization to match pretrained configurations.
+ 2. Normalizing input and output features based on dataset statistics.
+ 3. Adding a batch dimension.
+ 4. Moving all data to the specified device.
+
+ The post-processing pipeline handles the model's output by:
+ 1. Moving data to the CPU.
+ 2. Unnormalizing the output features to their original scale.
+
+ Args:
+ config: The configuration object for the VQ-BeT policy.
+ dataset_stats: A dictionary of statistics for normalization.
+
+ Returns:
+ A tuple containing the configured pre-processor and post-processor pipelines.
+ """
+
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py
index 03b02a280..7b13577f6 100644
--- a/src/lerobot/policies/vqbet/vqbet_utils.py
+++ b/src/lerobot/policies/vqbet/vqbet_utils.py
@@ -17,10 +17,10 @@
# limitations under the License.
import math
+from collections.abc import Callable
from functools import partial
from math import ceil
from random import randrange
-from typing import Callable
import torch
import torch.distributed as distributed
@@ -198,7 +198,7 @@ class GPT(nn.Module):
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
- print("number of parameters: {:.2f}M".format(n_params / 1e6))
+ print(f"number of parameters: {n_params / 1e6:.2f}M")
def forward(self, input, targets=None):
device = input.device
@@ -231,16 +231,6 @@ class GPT(nn.Module):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
- def crop_block_size(self, gpt_block_size):
- # model surgery to decrease the block size if necessary
- # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
- # but want to use a smaller block size for some smaller, simpler model
- assert gpt_block_size <= self.config.gpt_block_size
- self.config.gpt_block_size = gpt_block_size
- self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
- for block in self.transformer.h:
- block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
-
def configure_parameters(self):
"""
This long function is unfortunately doing something very simple and is being very defensive:
@@ -255,7 +245,7 @@ class GPT(nn.Module):
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _p in m.named_parameters():
- fpn = "{}.{}".format(mn, pn) if mn else pn # full param name
+ fpn = f"{mn}.{pn}" if mn else pn # full param name
if pn.endswith("bias"):
# all biases will not be decayed
no_decay.add(fpn)
@@ -270,13 +260,11 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
- assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
- str(inter_params)
+ assert len(inter_params) == 0, (
+ f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
)
assert len(param_dict.keys() - union_params) == 0, (
- "parameters {} were not separated into either decay/no_decay set!".format(
- str(param_dict.keys() - union_params),
- )
+ f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
)
decay = [param_dict[pn] for pn in sorted(decay)]
diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py
new file mode 100644
index 000000000..be11ac1af
--- /dev/null
+++ b/src/lerobot/processor/__init__.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .batch_processor import AddBatchDimensionProcessorStep
+from .converters import (
+ batch_to_transition,
+ create_transition,
+ transition_to_batch,
+)
+from .core import (
+ EnvAction,
+ EnvTransition,
+ PolicyAction,
+ RobotAction,
+ RobotObservation,
+ TransitionKey,
+)
+from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
+from .device_processor import DeviceProcessorStep
+from .factory import (
+ make_default_processors,
+ make_default_robot_action_processor,
+ make_default_robot_observation_processor,
+ make_default_teleop_action_processor,
+)
+from .gym_action_processor import (
+ Numpy2TorchActionProcessorStep,
+ Torch2NumpyActionProcessorStep,
+)
+from .hil_processor import (
+ AddTeleopActionAsComplimentaryDataStep,
+ AddTeleopEventsAsInfoStep,
+ GripperPenaltyProcessorStep,
+ ImageCropResizeProcessorStep,
+ InterventionActionProcessorStep,
+ RewardClassifierProcessorStep,
+ TimeLimitProcessorStep,
+)
+from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
+from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
+from .observation_processor import VanillaObservationProcessorStep
+from .pipeline import (
+ ActionProcessorStep,
+ ComplementaryDataProcessorStep,
+ DataProcessorPipeline,
+ DoneProcessorStep,
+ IdentityProcessorStep,
+ InfoProcessorStep,
+ ObservationProcessorStep,
+ PolicyActionProcessorStep,
+ PolicyProcessorPipeline,
+ ProcessorKwargs,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ RewardProcessorStep,
+ RobotActionProcessorStep,
+ RobotProcessorPipeline,
+ TruncatedProcessorStep,
+)
+from .policy_robot_bridge import (
+ PolicyActionToRobotActionProcessorStep,
+ RobotActionToPolicyActionProcessorStep,
+)
+from .rename_processor import RenameObservationsProcessorStep
+from .tokenizer_processor import TokenizerProcessorStep
+
+__all__ = [
+ "ActionProcessorStep",
+ "AddTeleopActionAsComplimentaryDataStep",
+ "AddTeleopEventsAsInfoStep",
+ "ComplementaryDataProcessorStep",
+ "batch_to_transition",
+ "create_transition",
+ "DeviceProcessorStep",
+ "DoneProcessorStep",
+ "EnvAction",
+ "EnvTransition",
+ "GripperPenaltyProcessorStep",
+ "hotswap_stats",
+ "IdentityProcessorStep",
+ "ImageCropResizeProcessorStep",
+ "InfoProcessorStep",
+ "InterventionActionProcessorStep",
+ "JointVelocityProcessorStep",
+ "make_default_processors",
+ "make_default_teleop_action_processor",
+ "make_default_robot_action_processor",
+ "make_default_robot_observation_processor",
+ "MapDeltaActionToRobotActionStep",
+ "MapTensorToDeltaActionDictStep",
+ "MotorCurrentProcessorStep",
+ "NormalizerProcessorStep",
+ "Numpy2TorchActionProcessorStep",
+ "ObservationProcessorStep",
+ "PolicyAction",
+ "PolicyActionProcessorStep",
+ "PolicyProcessorPipeline",
+ "ProcessorKwargs",
+ "ProcessorStep",
+ "ProcessorStepRegistry",
+ "RobotAction",
+ "RobotActionProcessorStep",
+ "RobotObservation",
+ "RenameObservationsProcessorStep",
+ "RewardClassifierProcessorStep",
+ "RewardProcessorStep",
+ "DataProcessorPipeline",
+ "TimeLimitProcessorStep",
+ "AddBatchDimensionProcessorStep",
+ "RobotProcessorPipeline",
+ "TokenizerProcessorStep",
+ "Torch2NumpyActionProcessorStep",
+ "RobotActionToPolicyActionProcessorStep",
+ "PolicyActionToRobotActionProcessorStep",
+ "transition_to_batch",
+ "TransitionKey",
+ "TruncatedProcessorStep",
+ "UnnormalizerProcessorStep",
+ "VanillaObservationProcessorStep",
+]
diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py
new file mode 100644
index 000000000..e1a90421f
--- /dev/null
+++ b/src/lerobot/processor/batch_processor.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script defines processor steps for adding a batch dimension to various components of an environment transition.
+
+These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model.
+"""
+
+from dataclasses import dataclass, field
+
+from torch import Tensor
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+
+from .core import EnvTransition, PolicyAction
+from .pipeline import (
+ ComplementaryDataProcessorStep,
+ ObservationProcessorStep,
+ PolicyActionProcessorStep,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ TransitionKey,
+)
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="to_batch_processor_action")
+class AddBatchDimensionActionStep(PolicyActionProcessorStep):
+ """
+ Processor step to add a batch dimension to a 1D tensor action.
+
+ This is useful for creating a batch of size 1 from a single action sample.
+ """
+
+ def action(self, action: PolicyAction) -> PolicyAction:
+ """
+ Adds a batch dimension to the action if it's a 1D tensor.
+
+ Args:
+ action: The action tensor.
+
+ Returns:
+ The action tensor with an added batch dimension.
+ """
+ if action.dim() != 1:
+ return action
+ return action.unsqueeze(0)
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Returns the input features unchanged.
+
+ Adding a batch dimension does not alter the feature definition.
+
+ Args:
+ features: A dictionary of policy features.
+
+ Returns:
+ The original dictionary of policy features.
+ """
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="to_batch_processor_observation")
+class AddBatchDimensionObservationStep(ObservationProcessorStep):
+ """
+ Processor step to add a batch dimension to observations.
+
+ It handles different types of observations:
+ - State vectors (1D tensors).
+ - Single images (3D tensors).
+ - Dictionaries of multiple images (3D tensors).
+ """
+
+ def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
+ """
+ Adds a batch dimension to tensor-based observations in the observation dictionary.
+
+ Args:
+ observation: The observation dictionary.
+
+ Returns:
+ The observation dictionary with batch dimensions added to tensors.
+ """
+ # Process state observations - add batch dim if 1D
+ for state_key in [OBS_STATE, OBS_ENV_STATE]:
+ if state_key in observation:
+ state_value = observation[state_key]
+ if isinstance(state_value, Tensor) and state_value.dim() == 1:
+ observation[state_key] = state_value.unsqueeze(0)
+
+ # Process single image observation - add batch dim if 3D
+ if OBS_IMAGE in observation:
+ image_value = observation[OBS_IMAGE]
+ if isinstance(image_value, Tensor) and image_value.dim() == 3:
+ observation[OBS_IMAGE] = image_value.unsqueeze(0)
+
+ # Process multiple image observations - add batch dim if 3D
+ for key, value in observation.items():
+ if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
+ observation[key] = value.unsqueeze(0)
+ return observation
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Returns the input features unchanged.
+
+ Adding a batch dimension does not alter the feature definition.
+
+ Args:
+ features: A dictionary of policy features.
+
+ Returns:
+ The original dictionary of policy features.
+ """
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
+class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
+ """
+ Processor step to add a batch dimension to complementary data fields.
+
+ Handles specific keys like 'task', 'index', and 'task_index' to make them batched.
+ - 'task' (str) is wrapped in a list.
+ - 'index' and 'task_index' (0D tensors) get a batch dimension.
+ """
+
+ def complementary_data(self, complementary_data: dict) -> dict:
+ """
+ Adds a batch dimension to specific fields in the complementary data dictionary.
+
+ Args:
+ complementary_data: The complementary data dictionary.
+
+ Returns:
+ The complementary data dictionary with batch dimensions added.
+ """
+ # Process task field - wrap string in list to add batch dimension
+ if "task" in complementary_data:
+ task_value = complementary_data["task"]
+ if isinstance(task_value, str):
+ complementary_data["task"] = [task_value]
+
+ # Process index field - add batch dim if 0D
+ if "index" in complementary_data:
+ index_value = complementary_data["index"]
+ if isinstance(index_value, Tensor) and index_value.dim() == 0:
+ complementary_data["index"] = index_value.unsqueeze(0)
+
+ # Process task_index field - add batch dim if 0D
+ if "task_index" in complementary_data:
+ task_index_value = complementary_data["task_index"]
+ if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
+ complementary_data["task_index"] = task_index_value.unsqueeze(0)
+ return complementary_data
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Returns the input features unchanged.
+
+ Adding a batch dimension does not alter the feature definition.
+
+ Args:
+ features: A dictionary of policy features.
+
+ Returns:
+ The original dictionary of policy features.
+ """
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="to_batch_processor")
+class AddBatchDimensionProcessorStep(ProcessorStep):
+ """
+ A composite processor step that adds a batch dimension to the entire environment transition.
+
+ This step combines individual processors for actions, observations, and complementary data
+ to create a batched transition (batch size 1) from a single-instance transition.
+
+ Attributes:
+ to_batch_action_processor: Processor for the action component.
+ to_batch_observation_processor: Processor for the observation component.
+ to_batch_complementary_data_processor: Processor for the complementary data component.
+ """
+
+ to_batch_action_processor: AddBatchDimensionActionStep = field(
+ default_factory=AddBatchDimensionActionStep
+ )
+ to_batch_observation_processor: AddBatchDimensionObservationStep = field(
+ default_factory=AddBatchDimensionObservationStep
+ )
+ to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
+ default_factory=AddBatchDimensionComplementaryDataStep
+ )
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """
+ Applies the batching process to all relevant parts of an environment transition.
+
+ Args:
+ transition: The environment transition to process.
+
+ Returns:
+ The environment transition with a batch dimension added.
+ """
+ if transition[TransitionKey.ACTION] is not None:
+ transition = self.to_batch_action_processor(transition)
+ if transition[TransitionKey.OBSERVATION] is not None:
+ transition = self.to_batch_observation_processor(transition)
+ if transition[TransitionKey.COMPLEMENTARY_DATA] is not None:
+ transition = self.to_batch_complementary_data_processor(transition)
+ return transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Returns the input features unchanged.
+
+ Adding a batch dimension does not alter the feature definition.
+
+ Args:
+ features: A dictionary of policy features.
+
+ Returns:
+ The original dictionary of policy features.
+ """
+ # NOTE: We ignore the batch dimension when transforming features
+ return features
diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py
new file mode 100644
index 000000000..6b0b67598
--- /dev/null
+++ b/src/lerobot/processor/converters.py
@@ -0,0 +1,414 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import singledispatch
+from typing import Any
+
+import numpy as np
+import torch
+
+from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
+
+from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
+
+
+@singledispatch
+def to_tensor(
+ value: Any,
+ *,
+ dtype: torch.dtype | None = torch.float32,
+ device: torch.device | str | None = None,
+) -> torch.Tensor:
+ """
+ Convert various data types to PyTorch tensors with configurable options.
+
+ This is a unified tensor conversion function using single dispatch to handle
+ different input types appropriately.
+
+ Args:
+ value: Input value to convert (tensor, array, scalar, sequence, etc.).
+ dtype: Target tensor dtype. If None, preserves original dtype.
+ device: Target device for the tensor.
+
+ Returns:
+ A PyTorch tensor.
+
+ Raises:
+ TypeError: If the input type is not supported.
+ """
+ raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
+
+
+@to_tensor.register(torch.Tensor)
+def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
+ """Handle conversion for existing PyTorch tensors."""
+ if dtype is not None:
+ value = value.to(dtype=dtype)
+ if device is not None:
+ value = value.to(device=device)
+ return value
+
+
+@to_tensor.register(np.ndarray)
+def _(
+ value: np.ndarray,
+ *,
+ dtype=torch.float32,
+ device=None,
+ **kwargs,
+) -> torch.Tensor:
+ """Handle conversion for numpy arrays."""
+ # Check for numpy scalars (0-dimensional arrays) and treat them as scalars.
+ if value.ndim == 0:
+ # Numpy scalars should be converted to 0-dimensional tensors.
+ scalar_value = value.item()
+ return torch.tensor(scalar_value, dtype=dtype, device=device)
+
+ # Create tensor from numpy array.
+ tensor = torch.from_numpy(value)
+
+ # Apply dtype and device conversion if specified.
+ if dtype is not None:
+ tensor = tensor.to(dtype=dtype)
+ if device is not None:
+ tensor = tensor.to(device=device)
+
+ return tensor
+
+
+@to_tensor.register(int)
+@to_tensor.register(float)
+@to_tensor.register(np.integer)
+@to_tensor.register(np.floating)
+def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
+ """Handle conversion for scalar values including numpy scalars."""
+ return torch.tensor(value, dtype=dtype, device=device)
+
+
+@to_tensor.register(list)
+@to_tensor.register(tuple)
+def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
+ """Handle conversion for sequences (lists, tuples)."""
+ return torch.tensor(value, dtype=dtype, device=device)
+
+
+@to_tensor.register(dict)
+def _(value: dict, *, device=None, **kwargs) -> dict:
+ """Handle conversion for dictionaries by recursively converting their values to tensors."""
+ if not value:
+ return {}
+
+ result = {}
+ for key, sub_value in value.items():
+ if sub_value is None:
+ continue
+
+ if isinstance(sub_value, dict):
+ # Recursively process nested dictionaries.
+ result[key] = to_tensor(
+ sub_value,
+ device=device,
+ **kwargs,
+ )
+ continue
+
+ # Convert individual values to tensors.
+ result[key] = to_tensor(
+ sub_value,
+ device=device,
+ **kwargs,
+ )
+ return result
+
+
+def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
+ """
+ Convert a PyTorch tensor to a numpy array or scalar if applicable.
+
+ If the input is not a tensor, it is returned unchanged.
+
+ Args:
+ x: The input, which can be a tensor or any other type.
+
+ Returns:
+ A numpy array, a scalar, or the original input.
+ """
+ if isinstance(x, torch.Tensor):
+ return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
+ return x
+
+
+def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
+ """
+ Extract complementary data from a batch dictionary.
+
+ This includes padding flags, task description, and indices.
+
+ Args:
+ batch: The batch dictionary.
+
+ Returns:
+ A dictionary with the extracted complementary data.
+ """
+ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
+ task_key = {"task": batch["task"]} if "task" in batch else {}
+ index_key = {"index": batch["index"]} if "index" in batch else {}
+ task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
+
+ return {**pad_keys, **task_key, **index_key, **task_index_key}
+
+
+def create_transition(
+ observation: dict[str, Any] | None = None,
+ action: PolicyAction | RobotAction | None = None,
+ reward: float = 0.0,
+ done: bool = False,
+ truncated: bool = False,
+ info: dict[str, Any] | None = None,
+ complementary_data: dict[str, Any] | None = None,
+) -> EnvTransition:
+ """
+ Create an `EnvTransition` dictionary with sensible defaults.
+
+ Args:
+ observation: Observation dictionary.
+ action: Action dictionary.
+ reward: Scalar reward value.
+ done: Episode termination flag.
+ truncated: Episode truncation flag.
+ info: Additional info dictionary.
+ complementary_data: Complementary data dictionary.
+
+ Returns:
+ A complete `EnvTransition` dictionary.
+ """
+ return {
+ TransitionKey.OBSERVATION: observation,
+ TransitionKey.ACTION: action,
+ TransitionKey.REWARD: reward,
+ TransitionKey.DONE: done,
+ TransitionKey.TRUNCATED: truncated,
+ TransitionKey.INFO: info if info is not None else {},
+ TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
+ }
+
+
+def robot_action_observation_to_transition(
+ action_observation: tuple[RobotAction, RobotObservation],
+) -> EnvTransition:
+ """
+ Convert a raw robot action and observation dictionary into a standardized `EnvTransition`.
+
+ Args:
+ action: The raw action dictionary from a teleoperation device or controller.
+ observation: The raw observation dictionary from the environment.
+
+ Returns:
+ An `EnvTransition` containing the formatted observation.
+ """
+ if not isinstance(action_observation, tuple):
+ raise ValueError("action_observation should be a tuple type with an action and observation")
+
+ action, observation = action_observation
+
+ if action is not None and not isinstance(action, dict):
+ raise ValueError(f"Action should be a RobotAction type got {type(action)}")
+
+ if observation is not None and not isinstance(observation, dict):
+ raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
+
+ return create_transition(action=action, observation=observation)
+
+
+def robot_action_to_transition(action: RobotAction) -> EnvTransition:
+ """
+ Convert a raw robot action dictionary into a standardized `EnvTransition`.
+
+ Args:
+ action: The raw action dictionary from a teleoperation device or controller.
+
+ Returns:
+ An `EnvTransition` containing the formatted action.
+ """
+ if not isinstance(action, dict):
+ raise ValueError(f"Action should be a RobotAction type got {type(action)}")
+ return create_transition(action=action)
+
+
+def observation_to_transition(observation: RobotObservation) -> EnvTransition:
+ """
+ Convert a raw robot observation dictionary into a standardized `EnvTransition`.
+
+ Args:
+ observation: The raw observation dictionary from the environment.
+
+ Returns:
+ An `EnvTransition` containing the formatted observation.
+ """
+ if not isinstance(observation, dict):
+ raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}")
+ return create_transition(observation=observation)
+
+
+def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
+ """
+ Extract a raw robot action dictionary for a robot from an `EnvTransition`.
+
+ This function searches for keys in the format "action.*.pos" or "action.*.vel"
+ and converts them into a flat dictionary suitable for sending to a robot controller.
+
+ Args:
+ transition: The `EnvTransition` containing the action.
+
+ Returns:
+ A dictionary representing the raw robot action.
+ """
+ if not isinstance(transition, dict):
+ raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
+
+ action = transition.get(TransitionKey.ACTION)
+ if not isinstance(action, dict):
+ raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}")
+ return transition.get(TransitionKey.ACTION)
+
+
+def transition_to_policy_action(transition: EnvTransition) -> PolicyAction:
+ """
+ Convert an `EnvTransition` to a `PolicyAction`.
+ """
+ if not isinstance(transition, dict):
+ raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
+
+ action = transition.get(TransitionKey.ACTION)
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+ return action
+
+
+def transition_to_observation(transition: EnvTransition) -> RobotObservation:
+ """
+ Convert an `EnvTransition` to a `RobotObservation`.
+ """
+ if not isinstance(transition, dict):
+ raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
+
+ observation = transition.get(TransitionKey.OBSERVATION)
+ if not isinstance(observation, dict):
+ raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}")
+ return observation
+
+
+def policy_action_to_transition(action: PolicyAction) -> EnvTransition:
+ """
+ Convert a `PolicyAction` to an `EnvTransition`.
+ """
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+ return create_transition(action=action)
+
+
+def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
+ """
+ Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`.
+
+ This function maps recognized keys from a batch to the `EnvTransition` structure,
+ filling in missing keys with sensible defaults.
+
+ Args:
+ batch: A batch dictionary.
+
+ Returns:
+ An `EnvTransition` dictionary.
+
+ Raises:
+ ValueError: If the input is not a dictionary.
+ """
+
+ # Validate input type.
+ if not isinstance(batch, dict):
+ raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
+
+ action = batch.get(ACTION)
+ if action is not None and not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+
+ # Extract observation and complementary data keys.
+ observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
+ complementary_data = _extract_complementary_data(batch)
+
+ return create_transition(
+ observation=observation_keys if observation_keys else None,
+ action=batch.get(ACTION),
+ reward=batch.get(REWARD, 0.0),
+ done=batch.get(DONE, False),
+ truncated=batch.get(TRUNCATED, False),
+ info=batch.get("info", {}),
+ complementary_data=complementary_data if complementary_data else None,
+ )
+
+
+def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
+ """
+ Convert an `EnvTransition` back to the canonical batch format used in LeRobot.
+
+ This is the inverse of `batch_to_transition`.
+
+ Args:
+ transition: The `EnvTransition` to convert.
+
+ Returns:
+ A batch dictionary with canonical LeRobot field names.
+ """
+ if not isinstance(transition, dict):
+ raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
+
+ batch = {
+ ACTION: transition.get(TransitionKey.ACTION),
+ REWARD: transition.get(TransitionKey.REWARD, 0.0),
+ DONE: transition.get(TransitionKey.DONE, False),
+ TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
+ "info": transition.get(TransitionKey.INFO, {}),
+ }
+
+ # Add complementary data.
+ comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+ if comp_data:
+ batch.update(comp_data)
+
+ # Flatten observation dictionary.
+ observation = transition.get(TransitionKey.OBSERVATION)
+ if isinstance(observation, dict):
+ batch.update(observation)
+
+ return batch
+
+
+def identity_transition(transition: EnvTransition) -> EnvTransition:
+ """
+ An identity function for transitions, returning the input unchanged.
+
+ Useful as a default or placeholder in processing pipelines.
+
+ Args:
+ tr: An `EnvTransition`.
+
+ Returns:
+ The same `EnvTransition`.
+ """
+ return transition
diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py
new file mode 100644
index 000000000..679ba8c54
--- /dev/null
+++ b/src/lerobot/processor/core.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, TypeAlias, TypedDict
+
+import numpy as np
+import torch
+
+
+class TransitionKey(str, Enum):
+ """Keys for accessing EnvTransition dictionary components."""
+
+ # TODO(Steven): Use consts
+ OBSERVATION = "observation"
+ ACTION = "action"
+ REWARD = "reward"
+ DONE = "done"
+ TRUNCATED = "truncated"
+ INFO = "info"
+ COMPLEMENTARY_DATA = "complementary_data"
+
+
+PolicyAction: TypeAlias = torch.Tensor
+RobotAction: TypeAlias = dict[str, Any]
+EnvAction: TypeAlias = np.ndarray
+RobotObservation: TypeAlias = dict[str, Any]
+
+
+EnvTransition = TypedDict(
+ "EnvTransition",
+ {
+ TransitionKey.OBSERVATION.value: dict[str, Any] | None,
+ TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
+ TransitionKey.REWARD.value: float | torch.Tensor | None,
+ TransitionKey.DONE.value: bool | torch.Tensor | None,
+ TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
+ TransitionKey.INFO.value: dict[str, Any] | None,
+ TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
+ },
+)
diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py
new file mode 100644
index 000000000..a8395637c
--- /dev/null
+++ b/src/lerobot/processor/delta_action_processor.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+
+from .core import PolicyAction, RobotAction
+from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
+
+
+@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
+@dataclass
+class MapTensorToDeltaActionDictStep(ActionProcessorStep):
+ """
+ Maps a flat action tensor from a policy to a structured delta action dictionary.
+
+ This step is typically used after a policy outputs a continuous action vector.
+ It decomposes the vector into named components for delta movements of the
+ end-effector (x, y, z) and optionally the gripper.
+
+ Attributes:
+ use_gripper: If True, assumes the 4th element of the tensor is the
+ gripper action.
+ """
+
+ use_gripper: bool = True
+
+ def action(self, action: PolicyAction) -> RobotAction:
+ if not isinstance(action, PolicyAction):
+ raise ValueError("Only PolicyAction is supported for this processor")
+
+ if action.dim() > 1:
+ action = action.squeeze(0)
+
+ # TODO (maractingi): add rotation
+ delta_action = {
+ "delta_x": action[0].item(),
+ "delta_y": action[1].item(),
+ "delta_z": action[2].item(),
+ }
+ if self.use_gripper:
+ delta_action["gripper"] = action[3].item()
+ return delta_action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for axis in ["x", "y", "z"]:
+ features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ if self.use_gripper:
+ features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+ return features
+
+
+@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
+@dataclass
+class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
+ """
+ Maps delta actions from teleoperators to robot target actions for inverse kinematics.
+
+ This step converts a dictionary of delta movements (e.g., from a gamepad)
+ into a target action format that includes an "enabled" flag and target
+ end-effector positions. It also handles scaling and noise filtering.
+
+ Attributes:
+ position_scale: A factor to scale the delta position inputs.
+ noise_threshold: The magnitude below which delta inputs are considered noise
+ and do not trigger an "enabled" state.
+ """
+
+ # Scale factors for delta movements
+ position_scale: float = 1.0
+ noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
+
+ def action(self, action: RobotAction) -> RobotAction:
+ # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
+ # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
+ delta_x = action.pop("delta_x")
+ delta_y = action.pop("delta_y")
+ delta_z = action.pop("delta_z")
+ gripper = action.pop("gripper")
+
+ # Determine if the teleoperator is actively providing input
+ # Consider enabled if any significant movement delta is detected
+ position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
+ enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
+
+ # Scale the deltas appropriately
+ scaled_delta_x = delta_x * self.position_scale
+ scaled_delta_y = delta_y * self.position_scale
+ scaled_delta_z = delta_z * self.position_scale
+
+ # For gamepad/keyboard, we don't have rotation input, so set to 0
+ # These could be extended in the future for more sophisticated teleoperators
+ target_wx = 0.0
+ target_wy = 0.0
+ target_wz = 0.0
+
+ # Update action with robot target format
+ action = {
+ "enabled": enabled,
+ "target_x": scaled_delta_x,
+ "target_y": scaled_delta_y,
+ "target_z": scaled_delta_z,
+ "target_wx": target_wx,
+ "target_wy": target_wy,
+ "target_wz": target_wz,
+ "gripper_vel": float(gripper),
+ }
+
+ return action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for axis in ["x", "y", "z", "gripper"]:
+ features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
+
+ for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
+ features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py
new file mode 100644
index 000000000..2d0dd0880
--- /dev/null
+++ b/src/lerobot/processor/device_processor.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script defines a processor step for moving environment transition data to a specific torch device and casting
+its floating-point precision.
+"""
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.utils.utils import get_safe_torch_device
+
+from .core import EnvTransition, PolicyAction, TransitionKey
+from .pipeline import ProcessorStep, ProcessorStepRegistry
+
+
+@ProcessorStepRegistry.register("device_processor")
+@dataclass
+class DeviceProcessorStep(ProcessorStep):
+ """
+ Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their
+ floating-point data type.
+
+ This is crucial for preparing data for model training or inference on hardware like GPUs.
+
+ Attributes:
+ device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0").
+ float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16").
+ If None, the dtype is not changed.
+ """
+
+ device: str = "cpu"
+ float_dtype: str | None = None
+
+ DTYPE_MAPPING = {
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "float64": torch.float64,
+ "bfloat16": torch.bfloat16,
+ "half": torch.float16,
+ "float": torch.float32,
+ "double": torch.float64,
+ }
+
+ def __post_init__(self):
+ """
+ Initializes the processor by converting string configurations to torch objects.
+
+ This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the
+ `float_dtype` string, converting it to a `torch.dtype` object.
+ """
+ self.tensor_device: torch.device = get_safe_torch_device(self.device)
+ # Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0")
+ self.device = self.tensor_device.type
+ self.non_blocking = "cuda" in str(self.device)
+
+ # Validate and convert float_dtype string to torch dtype
+ if self.float_dtype is not None:
+ if self.float_dtype not in self.DTYPE_MAPPING:
+ raise ValueError(
+ f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
+ )
+ self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
+ else:
+ self._target_float_dtype = None
+
+ def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Moves a single tensor to the target device and casts its dtype.
+
+ Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than
+ the target, which is useful when using frameworks like Accelerate.
+
+ Args:
+ tensor: The input torch.Tensor.
+
+ Returns:
+ The processed tensor on the correct device and with the correct dtype.
+ """
+ # Determine target device
+ if tensor.is_cuda and self.tensor_device.type == "cuda":
+ # Both tensor and target are on GPU - preserve tensor's GPU placement.
+ # This handles multi-GPU scenarios where Accelerate has already placed
+ # tensors on the correct GPU for each process.
+ target_device = tensor.device
+ else:
+ # Either tensor is on CPU, or we're configured for CPU.
+ # In both cases, use the configured device.
+ target_device = self.tensor_device
+
+ # MPS workaround: Convert float64 to float32 since MPS doesn't support float64
+ if target_device.type == "mps" and tensor.dtype == torch.float64:
+ tensor = tensor.to(dtype=torch.float32)
+
+ # Only move if necessary
+ if tensor.device != target_device:
+ tensor = tensor.to(target_device, non_blocking=self.non_blocking)
+
+ # Convert float dtype if specified and tensor is floating point
+ if self._target_float_dtype is not None and tensor.is_floating_point():
+ tensor = tensor.to(dtype=self._target_float_dtype)
+
+ return tensor
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """
+ Applies device and dtype conversion to all tensors in an environment transition.
+
+ It iterates through the transition, finds all `torch.Tensor` objects (including those nested in
+ dictionaries like `observation`), and processes them.
+
+ Args:
+ transition: The input `EnvTransition` object.
+
+ Returns:
+ A new `EnvTransition` object with all tensors moved to the target device and dtype.
+ """
+ new_transition = transition.copy()
+ action = new_transition.get(TransitionKey.ACTION)
+
+ if action is not None and not isinstance(action, PolicyAction):
+ raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}")
+
+ simple_tensor_keys = [
+ TransitionKey.ACTION,
+ TransitionKey.REWARD,
+ TransitionKey.DONE,
+ TransitionKey.TRUNCATED,
+ ]
+
+ dict_tensor_keys = [
+ TransitionKey.OBSERVATION,
+ TransitionKey.COMPLEMENTARY_DATA,
+ ]
+
+ # Process simple, top-level tensors
+ for key in simple_tensor_keys:
+ value = transition.get(key)
+ if isinstance(value, torch.Tensor):
+ new_transition[key] = self._process_tensor(value)
+
+ # Process tensors nested within dictionaries
+ for key in dict_tensor_keys:
+ data_dict = transition.get(key)
+ if data_dict is not None:
+ new_data_dict = {
+ k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
+ for k, v in data_dict.items()
+ }
+ new_transition[key] = new_data_dict
+
+ return new_transition
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the serializable configuration of the processor.
+
+ Returns:
+ A dictionary containing the device and float_dtype settings.
+ """
+ return {"device": self.device, "float_dtype": self.float_dtype}
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Returns the input features unchanged.
+
+ Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape).
+
+ Args:
+ features: A dictionary of policy features.
+
+ Returns:
+ The original dictionary of policy features.
+ """
+ return features
diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py
new file mode 100644
index 000000000..5a0c41072
--- /dev/null
+++ b/src/lerobot/processor/factory.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .converters import (
+ observation_to_transition,
+ robot_action_observation_to_transition,
+ transition_to_observation,
+ transition_to_robot_action,
+)
+from .core import RobotAction, RobotObservation
+from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
+
+
+def make_default_teleop_action_processor() -> RobotProcessorPipeline[
+ tuple[RobotAction, RobotObservation], RobotAction
+]:
+ teleop_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[IdentityProcessorStep()],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+ )
+ return teleop_action_processor
+
+
+def make_default_robot_action_processor() -> RobotProcessorPipeline[
+ tuple[RobotAction, RobotObservation], RobotAction
+]:
+ robot_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
+ steps=[IdentityProcessorStep()],
+ to_transition=robot_action_observation_to_transition,
+ to_output=transition_to_robot_action,
+ )
+ return robot_action_processor
+
+
+def make_default_robot_observation_processor() -> RobotProcessorPipeline[RobotObservation, RobotObservation]:
+ robot_observation_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
+ steps=[IdentityProcessorStep()],
+ to_transition=observation_to_transition,
+ to_output=transition_to_observation,
+ )
+ return robot_observation_processor
+
+
+def make_default_processors():
+ teleop_action_processor = make_default_teleop_action_processor()
+ robot_action_processor = make_default_robot_action_processor()
+ robot_observation_processor = make_default_robot_observation_processor()
+ return (teleop_action_processor, robot_action_processor, robot_observation_processor)
diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py
new file mode 100644
index 000000000..8fa8cfd86
--- /dev/null
+++ b/src/lerobot/processor/gym_action_processor.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+
+from .converters import to_tensor
+from .core import EnvAction, EnvTransition, PolicyAction
+from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
+
+
+@ProcessorStepRegistry.register("torch2numpy_action_processor")
+@dataclass
+class Torch2NumpyActionProcessorStep(ActionProcessorStep):
+ """
+ Converts a PyTorch tensor action to a NumPy array.
+
+ This step is useful when the output of a policy (typically a torch.Tensor)
+ needs to be passed to an environment or component that expects a NumPy array.
+
+ Attributes:
+ squeeze_batch_dim: If True, removes the first dimension of the array
+ if it is of size 1. This is useful for converting a
+ batched action of size (1, D) to a single action of size (D,).
+ """
+
+ squeeze_batch_dim: bool = True
+
+ def action(self, action: PolicyAction) -> EnvAction:
+ if not isinstance(action, PolicyAction):
+ raise TypeError(
+ f"Expected PolicyAction or None, got {type(action).__name__}. "
+ "Use appropriate processor for non-tensor actions."
+ )
+
+ numpy_action = action.detach().cpu().numpy()
+
+ # Remove batch dimensions but preserve action dimensions.
+ # Only squeeze if there's a batch dimension (first dim == 1).
+ if (
+ self.squeeze_batch_dim
+ and numpy_action.shape
+ and len(numpy_action.shape) > 1
+ and numpy_action.shape[0] == 1
+ ):
+ numpy_action = numpy_action.squeeze(0)
+
+ return numpy_action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@ProcessorStepRegistry.register("numpy2torch_action_processor")
+@dataclass
+class Numpy2TorchActionProcessorStep(ProcessorStep):
+ """Converts a NumPy array action to a PyTorch tensor when action is present."""
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Converts numpy action to torch tensor if action exists, otherwise passes through."""
+ from .core import TransitionKey
+
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ action = new_transition.get(TransitionKey.ACTION)
+ if action is not None:
+ if not isinstance(action, EnvAction):
+ raise TypeError(
+ f"Expected np.ndarray or None, got {type(action).__name__}. "
+ "Use appropriate processor for non-tensor actions."
+ )
+ torch_action = to_tensor(action, dtype=None) # Preserve original dtype
+ new_transition[TransitionKey.ACTION] = torch_action
+
+ return new_transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py
new file mode 100644
index 000000000..f0dbac9c3
--- /dev/null
+++ b/src/lerobot/processor/hil_processor.py
@@ -0,0 +1,596 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import time
+from dataclasses import dataclass
+from typing import Any, Protocol, TypeVar, runtime_checkable
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as F # noqa: N812
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.teleoperators.teleoperator import Teleoperator
+from lerobot.teleoperators.utils import TeleopEvents
+
+from .core import EnvTransition, PolicyAction, TransitionKey
+from .pipeline import (
+ ComplementaryDataProcessorStep,
+ InfoProcessorStep,
+ ObservationProcessorStep,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ TruncatedProcessorStep,
+)
+
+GRIPPER_KEY = "gripper"
+DISCRETE_PENALTY_KEY = "discrete_penalty"
+TELEOP_ACTION_KEY = "teleop_action"
+
+
+@runtime_checkable
+class HasTeleopEvents(Protocol):
+ """
+ Minimal protocol for objects that provide teleoperation events.
+
+ This protocol defines the `get_teleop_events()` method, allowing processor
+ steps to interact with teleoperators that support event-based controls
+ (like episode termination or success flagging) without needing to know the
+ teleoperator's specific class.
+ """
+
+ def get_teleop_events(self) -> dict[str, Any]:
+ """
+ Get extra control events from the teleoperator.
+
+ Returns:
+ A dictionary containing control events such as:
+ - `is_intervention`: bool - Whether the human is currently intervening.
+ - `terminate_episode`: bool - Whether to terminate the current episode.
+ - `success`: bool - Whether the episode was successful.
+ - `rerecord_episode`: bool - Whether to rerecord the episode.
+ """
+ ...
+
+
+# Type variable constrained to Teleoperator subclasses that also implement events
+TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
+
+
+def _check_teleop_with_events(teleop: Teleoperator) -> None:
+ """
+ Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
+
+ Args:
+ teleop: The teleoperator instance to check.
+
+ Raises:
+ TypeError: If the teleoperator does not have a `get_teleop_events` method.
+ """
+ if not isinstance(teleop, HasTeleopEvents):
+ raise TypeError(
+ f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
+ f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
+ )
+
+
+@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
+@dataclass
+class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
+ """
+ Adds the raw action from a teleoperator to the transition's complementary data.
+
+ This is useful for human-in-the-loop scenarios where the human's input needs to
+ be available to downstream processors, for example, to override a policy's action
+ during an intervention.
+
+ Attributes:
+ teleop_device: The teleoperator instance to get the action from.
+ """
+
+ teleop_device: Teleoperator
+
+ def complementary_data(self, complementary_data: dict) -> dict:
+ """
+ Retrieves the teleoperator's action and adds it to the complementary data.
+
+ Args:
+ complementary_data: The incoming complementary data dictionary.
+
+ Returns:
+ A new dictionary with the teleoperator action added under the
+ `teleop_action` key.
+ """
+ new_complementary_data = dict(complementary_data)
+ new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
+ return new_complementary_data
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@ProcessorStepRegistry.register("add_teleop_action_as_info")
+@dataclass
+class AddTeleopEventsAsInfoStep(InfoProcessorStep):
+ """
+ Adds teleoperator control events (e.g., terminate, success) to the transition's info.
+
+ This step extracts control events from teleoperators that support event-based
+ interaction, making these signals available to other parts of the system.
+
+ Attributes:
+ teleop_device: An instance of a teleoperator that implements the
+ `HasTeleopEvents` protocol.
+ """
+
+ teleop_device: TeleopWithEvents
+
+ def __post_init__(self):
+ """Validates that the provided teleoperator supports events after initialization."""
+ _check_teleop_with_events(self.teleop_device)
+
+ def info(self, info: dict) -> dict:
+ """
+ Retrieves teleoperator events and updates the info dictionary.
+
+ Args:
+ info: The incoming info dictionary.
+
+ Returns:
+ A new dictionary including the teleoperator events.
+ """
+ new_info = dict(info)
+
+ teleop_events = self.teleop_device.get_teleop_events()
+ new_info.update(teleop_events)
+ return new_info
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@ProcessorStepRegistry.register("image_crop_resize_processor")
+@dataclass
+class ImageCropResizeProcessorStep(ObservationProcessorStep):
+ """
+ Crops and/or resizes image observations.
+
+ This step iterates through all image keys in an observation dictionary and applies
+ the specified transformations. It handles device placement, moving tensors to the
+ CPU if necessary for operations not supported on certain accelerators like MPS.
+
+ Attributes:
+ crop_params_dict: A dictionary mapping image keys to cropping parameters
+ (top, left, height, width).
+ resize_size: A tuple (height, width) to resize all images to.
+ """
+
+ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
+ resize_size: tuple[int, int] | None = None
+
+ def observation(self, observation: dict) -> dict:
+ """
+ Applies cropping and resizing to all images in the observation dictionary.
+
+ Args:
+ observation: The observation dictionary, potentially containing image tensors.
+
+ Returns:
+ A new observation dictionary with transformed images.
+ """
+ if self.resize_size is None and not self.crop_params_dict:
+ return observation
+
+ new_observation = dict(observation)
+
+ # Process all image keys in the observation
+ for key in observation:
+ if "image" not in key:
+ continue
+
+ image = observation[key]
+ device = image.device
+ # NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
+ if device.type == "mps":
+ image = image.cpu()
+ # Crop if crop params are provided for this key
+ if self.crop_params_dict is not None and key in self.crop_params_dict:
+ crop_params = self.crop_params_dict[key]
+ image = F.crop(image, *crop_params)
+ if self.resize_size is not None:
+ image = F.resize(image, self.resize_size)
+ image = image.clamp(0.0, 1.0)
+ new_observation[key] = image.to(device)
+
+ return new_observation
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary with the crop parameters and resize dimensions.
+ """
+ return {
+ "crop_params_dict": self.crop_params_dict,
+ "resize_size": self.resize_size,
+ }
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Updates the image feature shapes in the policy features dictionary if resizing is applied.
+
+ Args:
+ features: The policy features dictionary.
+
+ Returns:
+ The updated policy features dictionary with new image shapes.
+ """
+ if self.resize_size is None:
+ return features
+ for key in features[PipelineFeatureType.OBSERVATION]:
+ if "image" in key:
+ nb_channel = features[PipelineFeatureType.OBSERVATION][key].shape[0]
+ features[PipelineFeatureType.OBSERVATION][key] = PolicyFeature(
+ type=features[PipelineFeatureType.OBSERVATION][key].type,
+ shape=(nb_channel, *self.resize_size),
+ )
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("time_limit_processor")
+class TimeLimitProcessorStep(TruncatedProcessorStep):
+ """
+ Tracks episode steps and enforces a time limit by truncating the episode.
+
+ Attributes:
+ max_episode_steps: The maximum number of steps allowed per episode.
+ current_step: The current step count for the active episode.
+ """
+
+ max_episode_steps: int
+ current_step: int = 0
+
+ def truncated(self, truncated: bool) -> bool:
+ """
+ Increments the step counter and sets the truncated flag if the time limit is reached.
+
+ Args:
+ truncated: The incoming truncated flag.
+
+ Returns:
+ True if the episode step limit is reached, otherwise the incoming value.
+ """
+ self.current_step += 1
+ if self.current_step >= self.max_episode_steps:
+ truncated = True
+ # TODO (steven): missing an else truncated = False?
+ return truncated
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary containing the `max_episode_steps`.
+ """
+ return {
+ "max_episode_steps": self.max_episode_steps,
+ }
+
+ def reset(self) -> None:
+ """Resets the step counter, typically called at the start of a new episode."""
+ self.current_step = 0
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("gripper_penalty_processor")
+class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
+ """
+ Applies a penalty for inefficient gripper usage.
+
+ This step penalizes actions that attempt to close an already closed gripper or
+ open an already open one, based on position thresholds.
+
+ Attributes:
+ penalty: The negative reward value to apply.
+ max_gripper_pos: The maximum position value for the gripper, used for normalization.
+ """
+
+ penalty: float = -0.01
+ max_gripper_pos: float = 30.0
+
+ def complementary_data(self, complementary_data: dict) -> dict:
+ """
+ Calculates the gripper penalty and adds it to the complementary data.
+
+ Args:
+ complementary_data: The incoming complementary data, which should contain
+ raw joint positions.
+
+ Returns:
+ A new complementary data dictionary with the `discrete_penalty` key added.
+ """
+ action = self.transition.get(TransitionKey.ACTION)
+
+ raw_joint_positions = complementary_data.get("raw_joint_positions")
+ if raw_joint_positions is None:
+ return complementary_data
+
+ current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
+ if current_gripper_pos is None:
+ return complementary_data
+
+ # Gripper action is a PolicyAction at this stage
+ gripper_action = action[-1].item()
+ gripper_action_normalized = gripper_action / self.max_gripper_pos
+
+ # Normalize gripper state and action
+ gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
+
+ # Calculate penalty boolean as in original
+ gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
+ gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
+ )
+
+ gripper_penalty = self.penalty * int(gripper_penalty_bool)
+
+ # Create new complementary data with penalty info
+ new_complementary_data = dict(complementary_data)
+ new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
+
+ return new_complementary_data
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary containing the penalty value and max gripper position.
+ """
+ return {
+ "penalty": self.penalty,
+ "max_gripper_pos": self.max_gripper_pos,
+ }
+
+ def reset(self) -> None:
+ """Resets the processor's internal state."""
+ pass
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("intervention_action_processor")
+class InterventionActionProcessorStep(ProcessorStep):
+ """
+ Handles human intervention, overriding policy actions and managing episode termination.
+
+ When an intervention is detected (via teleoperator events in the `info` dict),
+ this step replaces the policy's action with the human's teleoperated action.
+ It also processes signals to terminate the episode or flag success.
+
+ Attributes:
+ use_gripper: Whether to include the gripper in the teleoperated action.
+ terminate_on_success: If True, automatically sets the `done` flag when a
+ `success` event is received.
+ """
+
+ use_gripper: bool = False
+ terminate_on_success: bool = True
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """
+ Processes the transition to handle interventions.
+
+ Args:
+ transition: The incoming environment transition.
+
+ Returns:
+ The modified transition, potentially with an overridden action, updated
+ reward, and termination status.
+ """
+ action = transition.get(TransitionKey.ACTION)
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+
+ # Get intervention signals from complementary data
+ info = transition.get(TransitionKey.INFO, {})
+ complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+ teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
+ is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
+ terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
+ success = info.get(TeleopEvents.SUCCESS, False)
+ rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
+
+ new_transition = transition.copy()
+
+ # Override action if intervention is active
+ if is_intervention and teleop_action is not None:
+ if isinstance(teleop_action, dict):
+ # Convert teleop_action dict to tensor format
+ action_list = [
+ teleop_action.get("delta_x", 0.0),
+ teleop_action.get("delta_y", 0.0),
+ teleop_action.get("delta_z", 0.0),
+ ]
+ if self.use_gripper:
+ action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
+ elif isinstance(teleop_action, np.ndarray):
+ action_list = teleop_action.tolist()
+ else:
+ action_list = teleop_action
+
+ teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
+ new_transition[TransitionKey.ACTION] = teleop_action_tensor
+
+ # Handle episode termination
+ new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
+ self.terminate_on_success and success
+ )
+ new_transition[TransitionKey.REWARD] = float(success)
+
+ # Update info with intervention metadata
+ info = new_transition.get(TransitionKey.INFO, {})
+ info[TeleopEvents.IS_INTERVENTION] = is_intervention
+ info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
+ info[TeleopEvents.SUCCESS] = success
+ new_transition[TransitionKey.INFO] = info
+
+ # Update complementary data with teleop action
+ complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+ complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
+ new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
+
+ return new_transition
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary containing the step's configuration attributes.
+ """
+ return {
+ "use_gripper": self.use_gripper,
+ "terminate_on_success": self.terminate_on_success,
+ }
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("reward_classifier_processor")
+class RewardClassifierProcessorStep(ProcessorStep):
+ """
+ Applies a pretrained reward classifier to image observations to predict success.
+
+ This step uses a model to determine if the current state is successful, updating
+ the reward and potentially terminating the episode.
+
+ Attributes:
+ pretrained_path: Path to the pretrained reward classifier model.
+ device: The device to run the classifier on.
+ success_threshold: The probability threshold to consider a prediction as successful.
+ success_reward: The reward value to assign on success.
+ terminate_on_success: If True, terminates the episode upon successful classification.
+ reward_classifier: The loaded classifier model instance.
+ """
+
+ pretrained_path: str | None = None
+ device: str = "cpu"
+ success_threshold: float = 0.5
+ success_reward: float = 1.0
+ terminate_on_success: bool = True
+
+ reward_classifier: Any = None
+
+ def __post_init__(self):
+ """Initializes the reward classifier model after the dataclass is created."""
+ if self.pretrained_path is not None:
+ from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+
+ self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
+ self.reward_classifier.to(self.device)
+ self.reward_classifier.eval()
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """
+ Processes a transition, applying the reward classifier to its image observations.
+
+ Args:
+ transition: The incoming environment transition.
+
+ Returns:
+ The modified transition with an updated reward and done flag based on the
+ classifier's prediction.
+ """
+ new_transition = transition.copy()
+ observation = new_transition.get(TransitionKey.OBSERVATION)
+ if observation is None or self.reward_classifier is None:
+ return new_transition
+
+ # Extract images from observation
+ images = {key: value for key, value in observation.items() if "image" in key}
+
+ if not images:
+ return new_transition
+
+ # Run reward classifier
+ start_time = time.perf_counter()
+ with torch.inference_mode():
+ success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
+
+ classifier_frequency = 1 / (time.perf_counter() - start_time)
+
+ # Calculate reward and termination
+ reward = new_transition.get(TransitionKey.REWARD, 0.0)
+ terminated = new_transition.get(TransitionKey.DONE, False)
+
+ if math.isclose(success, 1, abs_tol=1e-2):
+ reward = self.success_reward
+ if self.terminate_on_success:
+ terminated = True
+
+ # Update transition
+ new_transition[TransitionKey.REWARD] = reward
+ new_transition[TransitionKey.DONE] = terminated
+
+ # Update info with classifier frequency
+ info = new_transition.get(TransitionKey.INFO, {})
+ info["reward_classifier_frequency"] = classifier_frequency
+ new_transition[TransitionKey.INFO] = info
+
+ return new_transition
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary containing the step's configuration attributes.
+ """
+ return {
+ "device": self.device,
+ "success_threshold": self.success_threshold,
+ "success_reward": self.success_reward,
+ "terminate_on_success": self.terminate_on_success,
+ }
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py
new file mode 100644
index 000000000..2fbcc7c46
--- /dev/null
+++ b/src/lerobot/processor/joint_observations_processor.py
@@ -0,0 +1,211 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.processor.pipeline import (
+ ObservationProcessorStep,
+ ProcessorStepRegistry,
+)
+from lerobot.robots import Robot
+from lerobot.utils.constants import OBS_STATE
+
+
+@dataclass
+@ProcessorStepRegistry.register("joint_velocity_processor")
+class JointVelocityProcessorStep(ObservationProcessorStep):
+ """
+ Calculates and appends joint velocity information to the observation state.
+
+ This step computes the velocity of each joint by calculating the finite
+ difference between the current and the last observed joint positions. The
+ resulting velocity vector is then concatenated to the original state vector.
+
+ Attributes:
+ dt: The time step (delta time) in seconds between observations, used for
+ calculating velocity.
+ last_joint_positions: Stores the joint positions from the previous step
+ to enable velocity calculation.
+ """
+
+ dt: float = 0.1
+
+ last_joint_positions: torch.Tensor | None = None
+
+ def observation(self, observation: dict) -> dict:
+ """
+ Computes joint velocities and adds them to the observation state.
+
+ Args:
+ observation: The input observation dictionary, expected to contain
+ an `observation.state` key with joint positions.
+
+ Returns:
+ A new observation dictionary with the `observation.state` tensor
+ extended to include joint velocities.
+
+ Raises:
+ ValueError: If `observation.state` is not found in the observation.
+ """
+ # Get current joint positions (assuming they're in observation.state)
+ current_positions = observation.get(OBS_STATE)
+ if current_positions is None:
+ raise ValueError(f"{OBS_STATE} is not in observation")
+
+ # Initialize last joint positions if not already set
+ if self.last_joint_positions is None:
+ self.last_joint_positions = current_positions.clone()
+ joint_velocities = torch.zeros_like(current_positions)
+ else:
+ # Compute velocities
+ joint_velocities = (current_positions - self.last_joint_positions) / self.dt
+
+ self.last_joint_positions = current_positions.clone()
+
+ # Extend observation with velocities
+ extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
+
+ # Create new observation dict
+ new_observation = dict(observation)
+ new_observation[OBS_STATE] = extended_state
+
+ return new_observation
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the configuration of the step for serialization.
+
+ Returns:
+ A dictionary containing the time step `dt`.
+ """
+ return {
+ "dt": self.dt,
+ }
+
+ def reset(self) -> None:
+ """Resets the internal state, clearing the last known joint positions."""
+ self.last_joint_positions = None
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Updates the `observation.state` feature to reflect the added velocities.
+
+ This method doubles the size of the first dimension of the `observation.state`
+ shape to account for the concatenation of position and velocity vectors.
+
+ Args:
+ features: The policy features dictionary.
+
+ Returns:
+ The updated policy features dictionary.
+ """
+ if OBS_STATE in features[PipelineFeatureType.OBSERVATION]:
+ original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
+ # Double the shape to account for positions + velocities
+ new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
+
+ features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
+ type=original_feature.type, shape=new_shape
+ )
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("current_processor")
+class MotorCurrentProcessorStep(ObservationProcessorStep):
+ """
+ Reads motor currents from a robot and appends them to the observation state.
+
+ This step queries the robot's hardware interface to get the present current
+ for each motor and concatenates this information to the existing state vector.
+
+ Attributes:
+ robot: An instance of a `lerobot` Robot class that provides access to
+ the hardware bus.
+ """
+
+ robot: Robot | None = None
+
+ def observation(self, observation: dict) -> dict:
+ """
+ Fetches motor currents and adds them to the observation state.
+
+ Args:
+ observation: The input observation dictionary.
+
+ Returns:
+ A new observation dictionary with the `observation.state` tensor
+ extended to include motor currents.
+
+ Raises:
+ ValueError: If the `robot` attribute has not been set.
+ """
+ # Get current values from robot state
+ if self.robot is None:
+ raise ValueError("Robot is not set")
+
+ present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
+ motor_currents = torch.tensor(
+ [present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
+ dtype=torch.float32,
+ ).unsqueeze(0)
+
+ current_state = observation.get(OBS_STATE)
+ if current_state is None:
+ return observation
+
+ extended_state = torch.cat([current_state, motor_currents], dim=-1)
+
+ # Create new observation dict
+ new_observation = dict(observation)
+ new_observation[OBS_STATE] = extended_state
+
+ return new_observation
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Updates the `observation.state` feature to reflect the added motor currents.
+
+ This method increases the size of the first dimension of the `observation.state`
+ shape by the number of motors in the robot.
+
+ Args:
+ features: The policy features dictionary.
+
+ Returns:
+ The updated policy features dictionary.
+ """
+ if OBS_STATE in features[PipelineFeatureType.OBSERVATION] and self.robot is not None:
+ original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
+ # Add motor current dimensions to the original state shape
+ num_motors = 0
+ if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
+ num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
+
+ if num_motors > 0:
+ new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
+ features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
+ type=original_feature.type, shape=new_shape
+ )
+ return features
diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py
new file mode 100644
index 000000000..525b7431c
--- /dev/null
+++ b/src/lerobot/processor/migrate_policy_normalization.py
@@ -0,0 +1,769 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+A generic script to migrate LeRobot policies with built-in normalization layers to the new
+pipeline-based processor system.
+
+This script performs the following steps:
+1. Loads a pretrained policy model and its configuration from a local path or the
+ Hugging Face Hub.
+2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
+ std, min, max) for all features.
+3. Creates two new processor pipelines:
+ - A preprocessor that normalizes inputs (observations) and outputs (actions).
+ - A postprocessor that unnormalizes outputs (actions) for inference.
+4. Removes the original normalization layers from the model's state dictionary,
+ creating a "clean" model.
+5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
+ model card to a new directory.
+6. Optionally pushes all the new artifacts to the Hugging Face Hub.
+
+Usage:
+ python src/lerobot/processor/migrate_policy_normalization.py \
+ --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
+ --push-to-hub \
+ --branch main
+
+Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config`
+factory functions from `lerobot.policies.factory` to create processors and configurations,
+ensuring consistency with the current codebase.
+
+The script extracts normalization statistics from the old model's state_dict, creates clean
+processor pipelines using the factory functions, and saves a migrated model that is compatible
+with the new PolicyProcessorPipeline architecture.
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import torch
+from huggingface_hub import HfApi, hf_hub_download
+from safetensors.torch import load_file as load_safetensors
+
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
+from lerobot.utils.constants import ACTION
+
+
+def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
+ """
+ Scans a model's state_dict to find and extract normalization statistics.
+
+ This function identifies keys corresponding to normalization layers (e.g., those
+ for mean, std, min, max) based on a set of predefined patterns and organizes
+ them into a nested dictionary.
+
+ Args:
+ state_dict: The state dictionary of a pretrained policy model.
+
+ Returns:
+ A nested dictionary where outer keys are feature names (e.g.,
+ 'observation.state') and inner keys are statistic types ('mean', 'std'),
+ mapping to their corresponding tensor values.
+ """
+ stats = {}
+
+ # Define patterns to match and their prefixes to remove
+ normalization_patterns = [
+ "normalize_inputs.buffer_",
+ "unnormalize_outputs.buffer_",
+ "normalize_targets.buffer_",
+ "normalize.", # Must come after normalize_* patterns
+ "unnormalize.", # Must come after unnormalize_* patterns
+ "input_normalizer.",
+ "output_normalizer.",
+ "normalalize_inputs.",
+ "unnormalize_outputs.",
+ "normalize_targets.",
+ "unnormalize_targets.",
+ ]
+
+ # Process each key in state_dict
+ for key, tensor in state_dict.items():
+ # Try each pattern
+ for pattern in normalization_patterns:
+ if key.startswith(pattern):
+ # Extract the remaining part after the pattern
+ remaining = key[len(pattern) :]
+ parts = remaining.split(".")
+
+ # Need at least feature name and stat type
+ if len(parts) >= 2:
+ # Last part is the stat type (mean, std, min, max, etc.)
+ stat_type = parts[-1]
+ # Everything else is the feature name
+ feature_name = ".".join(parts[:-1]).replace("_", ".")
+
+ # Add to stats
+ if feature_name not in stats:
+ stats[feature_name] = {}
+ stats[feature_name][stat_type] = tensor.clone()
+
+ # Only process the first matching pattern
+ break
+
+ return stats
+
+
+def detect_features_and_norm_modes(
+ config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
+) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
+ """
+ Infers policy features and normalization modes from the model config and stats.
+
+ This function first attempts to find feature definitions and normalization
+ mappings directly from the policy's configuration file. If this information is
+ not present, it infers it from the extracted normalization statistics, using
+ tensor shapes to determine feature shapes and the presence of specific stat
+ keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode.
+ It applies sensible defaults if inference is not possible.
+
+ Args:
+ config: The policy's configuration dictionary from `config.json`.
+ stats: The normalization statistics extracted from the model's state_dict.
+
+ Returns:
+ A tuple containing:
+ - A dictionary mapping feature names to `PolicyFeature` objects.
+ - A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
+ """
+ features = {}
+ norm_modes = {}
+
+ # First, check if there's a normalization_mapping in the config
+ if "normalization_mapping" in config:
+ print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
+ # Extract normalization modes from config
+ for feature_type_str, mode_str in config["normalization_mapping"].items():
+ # Convert string to FeatureType enum
+ try:
+ if feature_type_str == "VISUAL":
+ feature_type = FeatureType.VISUAL
+ elif feature_type_str == "STATE":
+ feature_type = FeatureType.STATE
+ elif feature_type_str == "ACTION":
+ feature_type = FeatureType.ACTION
+ else:
+ print(f"Warning: Unknown feature type '{feature_type_str}', skipping")
+ continue
+ except (AttributeError, ValueError):
+ print(f"Warning: Could not parse feature type '{feature_type_str}', skipping")
+ continue
+
+ # Convert string to NormalizationMode enum
+ try:
+ if mode_str == "MEAN_STD":
+ mode = NormalizationMode.MEAN_STD
+ elif mode_str == "MIN_MAX":
+ mode = NormalizationMode.MIN_MAX
+ elif mode_str == "IDENTITY":
+ mode = NormalizationMode.IDENTITY
+ else:
+ print(
+ f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'"
+ )
+ continue
+ except (AttributeError, ValueError):
+ print(f"Warning: Could not parse normalization mode '{mode_str}', skipping")
+ continue
+
+ norm_modes[feature_type] = mode
+
+ # Try to extract from config
+ if "features" in config:
+ for key, feature_config in config["features"].items():
+ shape = feature_config.get("shape", feature_config.get("dim"))
+ shape = (shape,) if isinstance(shape, int) else tuple(shape)
+
+ # Determine feature type
+ if "image" in key or "visual" in key:
+ feature_type = FeatureType.VISUAL
+ elif "state" in key:
+ feature_type = FeatureType.STATE
+ elif ACTION in key:
+ feature_type = FeatureType.ACTION
+ else:
+ feature_type = FeatureType.STATE # Default
+
+ features[key] = PolicyFeature(feature_type, shape)
+
+ # If no features in config, infer from stats
+ if not features:
+ for key, stat_dict in stats.items():
+ # Get shape from any stat tensor
+ tensor = next(iter(stat_dict.values()))
+ shape = tuple(tensor.shape)
+
+ # Determine feature type based on key
+ if "image" in key or "visual" in key or "pixels" in key:
+ feature_type = FeatureType.VISUAL
+ elif "state" in key or "joint" in key or "position" in key:
+ feature_type = FeatureType.STATE
+ elif ACTION in key:
+ feature_type = FeatureType.ACTION
+ else:
+ feature_type = FeatureType.STATE
+
+ features[key] = PolicyFeature(feature_type, shape)
+
+ # If normalization modes weren't in config, determine based on available stats
+ if not norm_modes:
+ for key, stat_dict in stats.items():
+ if key in features:
+ if "mean" in stat_dict and "std" in stat_dict:
+ feature_type = features[key].type
+ if feature_type not in norm_modes:
+ norm_modes[feature_type] = NormalizationMode.MEAN_STD
+ elif "min" in stat_dict and "max" in stat_dict:
+ feature_type = features[key].type
+ if feature_type not in norm_modes:
+ norm_modes[feature_type] = NormalizationMode.MIN_MAX
+
+ # Default normalization modes if not detected
+ if FeatureType.VISUAL not in norm_modes:
+ norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
+ if FeatureType.STATE not in norm_modes:
+ norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
+ if FeatureType.ACTION not in norm_modes:
+ norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
+
+ return features, norm_modes
+
+
+def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Creates a new state_dict with all normalization-related layers removed.
+
+ This function filters the original state dictionary, excluding any keys that
+ match a set of predefined patterns associated with normalization modules.
+
+ Args:
+ state_dict: The original model state dictionary.
+
+ Returns:
+ A new state dictionary containing only the core model weights, without
+ any normalization parameters.
+ """
+ new_state_dict = {}
+
+ # Patterns to remove
+ remove_patterns = [
+ "normalize_inputs.",
+ "unnormalize_outputs.",
+ "normalize_targets.", # Added pattern for target normalization
+ "normalize.",
+ "unnormalize.",
+ "input_normalizer.",
+ "output_normalizer.",
+ "normalizer.",
+ ]
+
+ for key, tensor in state_dict.items():
+ should_remove = any(pattern in key for pattern in remove_patterns)
+ if not should_remove:
+ new_state_dict[key] = tensor
+
+ return new_state_dict
+
+
+def clean_state_dict(
+ state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod"
+) -> dict[str, torch.Tensor]:
+ """
+ Remove a substring (e.g. '._orig_mod') from all keys in a state dict.
+
+ Args:
+ state_dict (dict): The original state dict.
+ remove_str (str): The substring to remove from the keys.
+
+ Returns:
+ dict: A new state dict with cleaned keys.
+ """
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ new_k = k.replace(remove_str, "")
+ new_state_dict[new_k] = v
+ return new_state_dict
+
+
+def load_state_dict_with_missing_key_handling(
+ policy: torch.nn.Module,
+ state_dict: dict[str, torch.Tensor],
+ policy_type: str,
+ known_missing_keys_whitelist: dict[str, list[str]],
+) -> list[str]:
+ """
+ Load state dict into policy with graceful handling of missing keys.
+
+ This function loads the state dict with strict=False, filters out whitelisted
+ missing keys, and provides detailed reporting about any issues found.
+
+ Args:
+ policy: The policy model to load the state dict into.
+ state_dict: The cleaned state dictionary to load.
+ policy_type: The type of policy (used for whitelist lookup).
+ known_missing_keys_whitelist: Dictionary mapping policy types to lists of
+ known acceptable missing keys.
+
+ Returns:
+ List of problematic missing keys that weren't in the whitelist.
+ """
+ # Load the cleaned state dict with strict=False to capture missing/unexpected keys
+ load_result = policy.load_state_dict(state_dict, strict=False)
+
+ # Check for missing keys
+ missing_keys = load_result.missing_keys
+ unexpected_keys = load_result.unexpected_keys
+
+ # Filter out whitelisted missing keys
+ policy_type_lower = policy_type.lower()
+ whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
+ problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
+
+ if missing_keys:
+ if problematic_missing_keys:
+ print(f"WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
+ for key in problematic_missing_keys:
+ print(f" - {key}")
+
+ if len(missing_keys) > len(problematic_missing_keys):
+ whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
+ print(f"INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
+ for key in whitelisted_missing:
+ print(f" - {key}")
+
+ if unexpected_keys:
+ print(f"WARNING: Found {len(unexpected_keys)} unexpected keys:")
+ for key in unexpected_keys:
+ print(f" - {key}")
+
+ if not missing_keys and not unexpected_keys:
+ print("Successfully loaded cleaned state dict into policy model (all keys matched)")
+ else:
+ print("State dict loaded with some missing/unexpected keys (see details above)")
+
+ return problematic_missing_keys
+
+
+def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
+ """
+ Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
+
+ Args:
+ features_dict: The feature dictionary in the old format, where values are
+ simple dictionaries (e.g., `{"shape": [7]}`).
+
+ Returns:
+ A dictionary mapping feature names to `PolicyFeature` dataclass objects.
+ """
+ converted_features = {}
+
+ for key, feature_dict in features_dict.items():
+ # Determine feature type based on key
+ if "image" in key or "visual" in key:
+ feature_type = FeatureType.VISUAL
+ elif "state" in key:
+ feature_type = FeatureType.STATE
+ elif ACTION in key:
+ feature_type = FeatureType.ACTION
+ else:
+ feature_type = FeatureType.STATE
+
+ # Get shape from feature dict
+ shape = feature_dict.get("shape", feature_dict.get("dim"))
+ shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else ()
+
+ converted_features[key] = PolicyFeature(feature_type, shape)
+
+ return converted_features
+
+
+def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
+ """
+ Display final migration summary with warnings about problematic missing keys.
+
+ Args:
+ problematic_missing_keys: List of missing keys that weren't in the whitelist.
+ """
+ if not problematic_missing_keys:
+ return
+
+ print("\n" + "=" * 60)
+ print("IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
+ print("=" * 60)
+ print(
+ f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
+ )
+ print()
+ for key in problematic_missing_keys:
+ print(f" - {key}")
+ print()
+ print("These missing keys may indicate:")
+ print(" • The model architecture has changed")
+ print(" • Some components were not properly saved in the original model")
+ print(" • The migration script needs to be updated for this policy type")
+ print()
+ print("What to do next:")
+ print(" 1. Test your migrated model carefully to ensure it works as expected")
+ print(" 2. If you encounter issues, please open an issue at:")
+ print(" https://github.com/huggingface/lerobot/issues")
+ print(" 3. Include this migration log and the missing keys listed above")
+ print()
+ print("If the model works correctly despite these warnings, the missing keys")
+ print("might be expected for your policy type and can be added to the whitelist.")
+ print("=" * 60)
+
+
+def load_model_from_hub(
+ repo_id: str, revision: str | None = None
+) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
+ """
+ Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
+
+ Args:
+ repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
+ revision: The specific git revision (branch, tag, or commit hash) to use.
+
+ Returns:
+ A tuple containing the model's state dictionary, the policy configuration,
+ and the training configuration (None if train_config.json is not found).
+ """
+ # Download files.
+ safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
+
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
+
+ # Load state_dict
+ state_dict = load_safetensors(safetensors_path)
+
+ # Load config
+ with open(config_path) as f:
+ config = json.load(f)
+
+ # Try to load train_config (optional)
+ train_config = None
+ try:
+ train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
+ with open(train_config_path) as f:
+ train_config = json.load(f)
+ except FileNotFoundError:
+ print("train_config.json not found - continuing without training configuration")
+
+ return state_dict, config, train_config
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Migrate policy models with normalization layers to new pipeline system"
+ )
+ parser.add_argument(
+ "--pretrained-path",
+ type=str,
+ required=True,
+ help="Path to pretrained model (hub repo or local directory)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="Output directory for migrated model (default: same as pretrained-path)",
+ )
+ parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
+ parser.add_argument(
+ "--hub-repo-id",
+ type=str,
+ default=None,
+ help="Hub repository ID for pushing (default: same as pretrained-path)",
+ )
+ parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
+ parser.add_argument("--private", action="store_true", help="Make the hub repository private")
+ parser.add_argument(
+ "--branch",
+ type=str,
+ default=None,
+ help="Git branch to use when pushing to hub. If specified, a PR will be created automatically (default: push directly to main)",
+ )
+
+ args = parser.parse_args()
+
+ # Load model and config
+ print(f"Loading model from {args.pretrained_path}...")
+ if os.path.isdir(args.pretrained_path):
+ # Local directory
+ state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
+ with open(os.path.join(args.pretrained_path, "config.json")) as f:
+ config = json.load(f)
+
+ # Try to load train_config (optional)
+ train_config = None
+ train_config_path = os.path.join(args.pretrained_path, "train_config.json")
+ if os.path.exists(train_config_path):
+ with open(train_config_path) as f:
+ train_config = json.load(f)
+ else:
+ print("train_config.json not found - continuing without training configuration")
+ else:
+ # Hub repository
+ state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
+
+ # Extract normalization statistics
+ print("Extracting normalization statistics...")
+ stats = extract_normalization_stats(state_dict)
+
+ print(f"Found normalization statistics for: {list(stats.keys())}")
+
+ # Detect input features and normalization modes
+ print("Detecting features and normalization modes...")
+ features, norm_map = detect_features_and_norm_modes(config, stats)
+
+ print(f"Detected features: {list(features.keys())}")
+ print(f"Normalization modes: {norm_map}")
+
+ # Remove normalization layers from state_dict
+ print("Removing normalization layers from model...")
+ new_state_dict = remove_normalization_layers(state_dict)
+ new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod")
+
+ removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
+ if removed_keys:
+ print(f"Removed {len(removed_keys)} normalization layer keys")
+
+ # Determine output path
+ if args.output_dir:
+ output_dir = Path(args.output_dir)
+ else:
+ if os.path.isdir(args.pretrained_path):
+ output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
+ else:
+ output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Extract policy type from config
+ if "type" not in config:
+ raise ValueError("Policy type not found in config.json. The config must contain a 'type' field.")
+
+ policy_type = config["type"]
+ print(f"Detected policy type: {policy_type}")
+
+ # Clean up config - remove fields that shouldn't be passed to config constructor
+ cleaned_config = dict(config)
+
+ # Remove fields that are not part of the config class constructors
+ fields_to_remove = ["normalization_mapping", "type"]
+ for field in fields_to_remove:
+ if field in cleaned_config:
+ print(f"Removing '{field}' field from config")
+ del cleaned_config[field]
+
+ # Convert input_features and output_features to PolicyFeature objects if they exist
+ if "input_features" in cleaned_config:
+ cleaned_config["input_features"] = convert_features_to_policy_features(
+ cleaned_config["input_features"]
+ )
+ if "output_features" in cleaned_config:
+ cleaned_config["output_features"] = convert_features_to_policy_features(
+ cleaned_config["output_features"]
+ )
+
+ # Add normalization mapping to config
+ cleaned_config["normalization_mapping"] = norm_map
+
+ # Create policy configuration using the factory
+ print(f"Creating {policy_type} policy configuration...")
+ policy_config = make_policy_config(policy_type, **cleaned_config)
+
+ # Create policy instance using the factory
+ print(f"Instantiating {policy_type} policy...")
+ policy_class = get_policy_class(policy_type)
+ policy = policy_class(policy_config)
+
+ # Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types
+ known_missing_keys_whitelist = {
+ "pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
+ # Add other policy types and their known missing keys here as needed
+ }
+
+ # Load state dict with graceful missing key handling
+ problematic_missing_keys = load_state_dict_with_missing_key_handling(
+ policy=policy,
+ state_dict=new_state_dict,
+ policy_type=policy_type,
+ known_missing_keys_whitelist=known_missing_keys_whitelist,
+ )
+ policy.to(torch.float32)
+ # Create preprocessor and postprocessor using the factory
+ print("Creating preprocessor and postprocessor using make_pre_post_processors...")
+ preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
+
+ # Determine hub repo ID if pushing to hub
+ hub_repo_id = None
+ if args.push_to_hub:
+ if args.hub_repo_id:
+ hub_repo_id = args.hub_repo_id
+ else:
+ if not os.path.isdir(args.pretrained_path):
+ # Use same repo with "_migrated" suffix
+ hub_repo_id = f"{args.pretrained_path}_migrated"
+ else:
+ raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
+
+ # Save all components to local directory first
+ print(f"Saving preprocessor to {output_dir}...")
+ preprocessor.save_pretrained(output_dir)
+
+ print(f"Saving postprocessor to {output_dir}...")
+ postprocessor.save_pretrained(output_dir)
+
+ print(f"Saving model to {output_dir}...")
+ policy.save_pretrained(output_dir)
+
+ # Generate and save model card
+ print("Generating model card...")
+ # Get metadata from original config
+ dataset_repo_id = "unknown"
+ if train_config is not None:
+ dataset_repo_id = train_config.get("repo_id", "unknown")
+ license = config.get("license", "apache-2.0")
+
+ tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
+ tags = set(tags).union({"robotics", "lerobot", policy_type})
+ tags = list(tags)
+
+ # Generate model card
+ card = policy.generate_model_card(
+ dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
+ )
+
+ # Save model card locally
+ card.save(str(output_dir / "README.md"))
+ print(f"Model card saved to {output_dir / 'README.md'}")
+ # Push all files to hub in a single operation if requested
+ if args.push_to_hub and hub_repo_id:
+ api = HfApi()
+
+ # Determine if we should create a PR (automatically if branch is specified)
+ create_pr = args.branch is not None
+ target_location = f"branch '{args.branch}'" if args.branch else "main branch"
+
+ print(f"Pushing all migrated files to {hub_repo_id} on {target_location}...")
+
+ # Upload all files in a single commit with automatic PR creation if branch specified
+ commit_message = "Migrate policy to PolicyProcessorPipeline system"
+ commit_description = None
+
+ if create_pr:
+ # Separate commit description for PR body
+ commit_description = """**Automated Policy Migration to PolicyProcessorPipeline**
+
+This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture.
+
+## What Changed
+
+### **New Architecture - PolicyProcessorPipeline**
+Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides:
+- **Modularity**: Separate preprocessing and postprocessing pipelines
+- **Flexibility**: Easy to swap, configure, and debug processing steps
+- **Compatibility**: Works with the latest LeRobot ecosystem
+
+### **Normalization Extraction**
+We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers:
+- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*`
+- **Statistics preserved**: Mean, std, min, max values for all features
+- **Clean model**: State dict now contains only core model weights
+
+### **Files Added**
+- **preprocessor_config.json**: Configuration for input preprocessing pipeline
+- **postprocessor_config.json**: Configuration for output postprocessing pipeline
+- **model.safetensors**: Clean model weights without normalization layers
+- **config.json**: Updated model configuration
+- **train_config.json**: Training configuration
+- **README.md**: Updated model card with migration information
+
+### **Benefits**
+- **Backward Compatible**: Your model behavior remains identical
+- **Future Ready**: Compatible with latest LeRobot features and updates
+- **Debuggable**: Easy to inspect and modify processing steps
+- **Portable**: Processors can be shared and reused across models
+
+### **Usage**
+```python
+# Load your migrated model
+from lerobot.policies import get_policy_class
+from lerobot.processor import PolicyProcessorPipeline
+
+# The preprocessor and postprocessor are now external
+preprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="preprocessor_config.json")
+postprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="postprocessor_config.json")
+policy = get_policy_class("your-policy-type").from_pretrained("your-model-repo")
+
+# Process data through the pipeline
+processed_batch = preprocessor(raw_batch)
+action = policy(processed_batch)
+final_action = postprocessor(action)
+```
+
+*Generated automatically by the LeRobot policy migration script*"""
+
+ upload_kwargs = {
+ "repo_id": hub_repo_id,
+ "folder_path": output_dir,
+ "repo_type": "model",
+ "commit_message": commit_message,
+ "revision": args.branch,
+ "create_pr": create_pr,
+ "allow_patterns": ["*.json", "*.safetensors", "*.md"],
+ "ignore_patterns": ["*.tmp", "*.log"],
+ }
+
+ # Add commit_description for PR body if creating PR
+ if create_pr and commit_description:
+ upload_kwargs["commit_description"] = commit_description
+
+ api.upload_folder(**upload_kwargs)
+
+ if create_pr:
+ print("All files pushed and pull request created successfully!")
+ else:
+ print("All files pushed to main branch successfully!")
+
+ print("\nMigration complete!")
+ print(f"Migrated model saved to: {output_dir}")
+ if args.push_to_hub and hub_repo_id:
+ if args.branch:
+ print(
+ f"Successfully pushed all files to branch '{args.branch}' and created PR on https://huggingface.co/{hub_repo_id}"
+ )
+ else:
+ print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
+ if args.branch:
+ print(f"\nView the branch at: https://huggingface.co/{hub_repo_id}/tree/{args.branch}")
+ print(
+ f"View the PR at: https://huggingface.co/{hub_repo_id}/discussions (look for the most recent PR)"
+ )
+ else:
+ print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
+
+ # Display final summary about any problematic missing keys
+ display_migration_summary_with_warnings(problematic_missing_keys)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
new file mode 100644
index 000000000..368c9b270
--- /dev/null
+++ b/src/lerobot/processor/normalize_processor.py
@@ -0,0 +1,560 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any
+
+import torch
+from torch import Tensor
+
+from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import ACTION
+
+from .converters import from_tensor_to_numpy, to_tensor
+from .core import EnvTransition, PolicyAction, TransitionKey
+from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
+
+
+@dataclass
+class _NormalizationMixin:
+ """
+ A mixin class providing core functionality for normalization and unnormalization.
+
+ This class manages normalization statistics (`stats`), converts them to tensors for
+ efficient computation, handles device placement, and implements the logic for
+ applying normalization transformations (mean/std and min/max). It is designed to
+ be inherited by concrete `ProcessorStep` implementations and should not be used
+ directly.
+
+ **Stats Override Preservation:**
+ When stats are explicitly provided during construction (e.g., via overrides in
+ `DataProcessorPipeline.from_pretrained()`), they are preserved even when
+ `load_state_dict()` is called. This allows users to override normalization
+ statistics from saved models while keeping the rest of the model state intact.
+
+ Examples:
+ ```python
+ # Common use case: Override with dataset stats
+ from lerobot.datasets import LeRobotDataset
+
+ dataset = LeRobotDataset("my_dataset")
+ pipeline = DataProcessorPipeline.from_pretrained(
+ "model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}}
+ )
+ # dataset.meta.stats will be used, not the stats from the saved model
+
+ # Custom stats override
+ custom_stats = {"action": {"mean": [0.0], "std": [1.0]}}
+ pipeline = DataProcessorPipeline.from_pretrained(
+ "model_path", overrides={"normalizer_processor": {"stats": custom_stats}}
+ )
+ ```
+
+ Attributes:
+ features: A dictionary mapping feature names to `PolicyFeature` objects, defining
+ the data structure to be processed.
+ norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying
+ which normalization method to use for each type of feature.
+ stats: A dictionary containing the normalization statistics (e.g., mean, std,
+ min, max) for each feature.
+ device: The PyTorch device on which to store and perform tensor operations.
+ eps: A small epsilon value to prevent division by zero in normalization
+ calculations.
+ normalize_observation_keys: An optional set of keys to selectively apply
+ normalization to specific observation features.
+ _tensor_stats: An internal dictionary holding the normalization statistics as
+ PyTorch tensors.
+ _stats_explicitly_provided: Internal flag tracking whether stats were explicitly
+ provided during construction (used for override preservation).
+ """
+
+ features: dict[str, PolicyFeature]
+ norm_map: dict[FeatureType, NormalizationMode]
+ stats: dict[str, dict[str, Any]] | None = None
+ device: torch.device | str | None = None
+ dtype: torch.dtype | None = None
+ eps: float = 1e-8
+ normalize_observation_keys: set[str] | None = None
+
+ _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
+ _stats_explicitly_provided: bool = field(default=False, init=False, repr=False)
+
+ def __post_init__(self):
+ """
+ Initializes the mixin after dataclass construction.
+
+ This method handles the robust deserialization of `features` and `norm_map`
+ from JSON-compatible formats (where enums become strings and tuples become
+ lists) and converts the provided `stats` dictionary into a dictionary of
+ tensors (`_tensor_stats`) on the specified device.
+ """
+ # Track if stats were explicitly provided (not None and not empty)
+ self._stats_explicitly_provided = self.stats is not None and bool(self.stats)
+ # Robust JSON deserialization handling (guard empty maps).
+ if self.features:
+ first_val = next(iter(self.features.values()))
+ if isinstance(first_val, dict):
+ reconstructed = {}
+ for key, ft_dict in self.features.items():
+ reconstructed[key] = PolicyFeature(
+ type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
+ )
+ self.features = reconstructed
+
+ # if keys are strings (JSON), rebuild enum map
+ if self.norm_map and all(isinstance(k, str) for k in self.norm_map):
+ reconstructed = {}
+ for ft_type_str, norm_mode_str in self.norm_map.items():
+ reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
+ self.norm_map = reconstructed
+
+ # Convert stats to tensors and move to the target device once during initialization.
+ self.stats = self.stats or {}
+ if self.dtype is None:
+ self.dtype = torch.float32
+ self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
+
+ def to(
+ self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+ ) -> _NormalizationMixin:
+ """
+ Moves the processor's normalization stats to the specified device.
+
+ Args:
+ device: The target PyTorch device.
+
+ Returns:
+ The instance of the class, allowing for method chaining.
+ """
+ if device is not None:
+ self.device = device
+ if dtype is not None:
+ self.dtype = dtype
+ self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
+ return self
+
+ def state_dict(self) -> dict[str, Tensor]:
+ """
+ Returns the normalization statistics as a flat state dictionary.
+
+ All tensors are moved to the CPU before being returned, which is standard practice
+ for saving state dictionaries.
+
+ Returns:
+ A flat dictionary mapping from `'feature_name.stat_name'` to the
+ corresponding statistics tensor on the CPU.
+ """
+ flat: dict[str, Tensor] = {}
+ for key, sub in self._tensor_stats.items():
+ for stat_name, tensor in sub.items():
+ flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
+ return flat
+
+ def load_state_dict(self, state: dict[str, Tensor]) -> None:
+ """
+ Loads normalization statistics from a state dictionary.
+
+ The loaded tensors are moved to the processor's configured device.
+
+ **Stats Override Preservation:**
+ If stats were explicitly provided during construction (e.g., via overrides in
+ `DataProcessorPipeline.from_pretrained()`), they are preserved and the state
+ dictionary is ignored. This allows users to override normalization statistics
+ while still loading the rest of the model state.
+
+ This behavior is crucial for scenarios where users want to adapt a pretrained
+ model to a new dataset with different statistics without retraining the entire
+ model.
+
+ Args:
+ state: A flat state dictionary with keys in the format
+ `'feature_name.stat_name'`.
+
+ Note:
+ When stats are preserved due to explicit provision, only the tensor
+ representation is updated to ensure consistency with the current device
+ and dtype settings.
+ """
+ # If stats were explicitly provided during construction, preserve them
+ if self._stats_explicitly_provided and self.stats is not None:
+ # Don't load from state_dict, keep the explicitly provided stats
+ # But ensure _tensor_stats is properly initialized
+ self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
+ return
+
+ # Normal behavior: load stats from state_dict
+ self._tensor_stats.clear()
+ for flat_key, tensor in state.items():
+ key, stat_name = flat_key.rsplit(".", 1)
+ # Load to the processor's configured device.
+ self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
+ dtype=torch.float32, device=self.device
+ )
+
+ # Reconstruct the original stats dict from tensor stats for compatibility with to() method
+ # and other functions that rely on self.stats
+ self.stats = {}
+ for key, tensor_dict in self._tensor_stats.items():
+ self.stats[key] = {}
+ for stat_name, tensor in tensor_dict.items():
+ # Convert tensor back to python/numpy format
+ self.stats[key][stat_name] = from_tensor_to_numpy(tensor)
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns a serializable dictionary of the processor's configuration.
+
+ This method is used when saving the processor to disk, ensuring that its
+ configuration can be reconstructed later.
+
+ Returns:
+ A JSON-serializable dictionary containing the configuration.
+ """
+ config = {
+ "eps": self.eps,
+ "features": {
+ key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
+ },
+ "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
+ }
+ if self.normalize_observation_keys is not None:
+ config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
+ return config
+
+ def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
+ """
+ Applies (un)normalization to all relevant features in an observation dictionary.
+
+ Args:
+ observation: The observation dictionary to process.
+ inverse: If `True`, applies unnormalization; otherwise, applies normalization.
+
+ Returns:
+ A new observation dictionary with the transformed tensor values.
+ """
+ new_observation = dict(observation)
+ for key, feature in self.features.items():
+ if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
+ continue
+ if feature.type != FeatureType.ACTION and key in new_observation:
+ # Convert to tensor but preserve original dtype for adaptation logic
+ tensor = torch.as_tensor(new_observation[key])
+ new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
+ return new_observation
+
+ def _normalize_action(self, action: Tensor, inverse: bool) -> Tensor:
+ # Convert to tensor but preserve original dtype for adaptation logic
+ """
+ Applies (un)normalization to an action tensor.
+
+ Args:
+ action: The action tensor to process.
+ inverse: If `True`, applies unnormalization; otherwise, applies normalization.
+
+ Returns:
+ The transformed action tensor.
+ """
+ processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse)
+ return processed_action
+
+ def _apply_transform(
+ self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
+ ) -> Tensor:
+ """
+ Core logic to apply a normalization or unnormalization transformation to a tensor.
+
+ This method selects the appropriate normalization mode based on the feature type
+ and applies the corresponding mathematical operation.
+
+ Normalization Modes:
+ - MEAN_STD: Centers data around zero with unit variance.
+ - MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
+ - QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
+ - QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
+
+ Args:
+ tensor: The input tensor to transform.
+ key: The feature key corresponding to the tensor.
+ feature_type: The `FeatureType` of the tensor.
+ inverse: If `True`, applies the inverse transformation (unnormalization).
+
+ Returns:
+ The transformed tensor.
+
+ Raises:
+ ValueError: If an unsupported normalization mode is encountered.
+ """
+ norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
+ if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
+ return tensor
+
+ if norm_mode not in (
+ NormalizationMode.MEAN_STD,
+ NormalizationMode.MIN_MAX,
+ NormalizationMode.QUANTILES,
+ NormalizationMode.QUANTILE10,
+ ):
+ raise ValueError(f"Unsupported normalization mode: {norm_mode}")
+
+ # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
+ if self._tensor_stats and key in self._tensor_stats:
+ first_stat = next(iter(self._tensor_stats[key].values()))
+ if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
+ self.to(device=tensor.device, dtype=tensor.dtype)
+
+ stats = self._tensor_stats[key]
+
+ if norm_mode == NormalizationMode.MEAN_STD:
+ mean = stats.get("mean", None)
+ std = stats.get("std", None)
+ if mean is None or std is None:
+ raise ValueError(
+ "MEAN_STD normalization mode requires mean and std stats, please update the dataset with the correct stats"
+ )
+
+ mean, std = stats["mean"], stats["std"]
+ # Avoid division by zero by adding a small epsilon.
+ denom = std + self.eps
+ if inverse:
+ return tensor * std + mean
+ return (tensor - mean) / denom
+
+ if norm_mode == NormalizationMode.MIN_MAX:
+ min_val = stats.get("min", None)
+ max_val = stats.get("max", None)
+ if min_val is None or max_val is None:
+ raise ValueError(
+ "MIN_MAX normalization mode requires min and max stats, please update the dataset with the correct stats"
+ )
+
+ min_val, max_val = stats["min"], stats["max"]
+ denom = max_val - min_val
+ # When min_val == max_val, substitute the denominator with a small epsilon
+ # to prevent division by zero. This consistently maps an input equal to
+ # min_val to -1, ensuring a stable transformation.
+ denom = torch.where(
+ denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
+ )
+ if inverse:
+ # Map from [-1, 1] back to [min, max]
+ return (tensor + 1) / 2 * denom + min_val
+ # Map from [min, max] to [-1, 1]
+ return 2 * (tensor - min_val) / denom - 1
+
+ if norm_mode == NormalizationMode.QUANTILES:
+ q01 = stats.get("q01", None)
+ q99 = stats.get("q99", None)
+ if q01 is None or q99 is None:
+ raise ValueError(
+ "QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
+ )
+
+ denom = q99 - q01
+ # Avoid division by zero by adding epsilon when quantiles are identical
+ denom = torch.where(
+ denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
+ )
+ if inverse:
+ return (tensor + 1.0) * denom / 2.0 + q01
+ return 2.0 * (tensor - q01) / denom - 1.0
+
+ if norm_mode == NormalizationMode.QUANTILE10:
+ q10 = stats.get("q10", None)
+ q90 = stats.get("q90", None)
+ if q10 is None or q90 is None:
+ raise ValueError(
+ "QUANTILE10 normalization mode requires q10 and q90 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
+ )
+
+ denom = q90 - q10
+ # Avoid division by zero by adding epsilon when quantiles are identical
+ denom = torch.where(
+ denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
+ )
+ if inverse:
+ return (tensor + 1.0) * denom / 2.0 + q10
+ return 2.0 * (tensor - q10) / denom - 1.0
+
+ # If necessary stats are missing, return input unchanged.
+ return tensor
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="normalizer_processor")
+class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
+ """
+ A processor step that applies normalization to observations and actions in a transition.
+
+ This class uses the logic from `_NormalizationMixin` to perform forward normalization
+ (e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]).
+ It is typically used in the pre-processing pipeline before feeding data to a policy.
+ """
+
+ @classmethod
+ def from_lerobot_dataset(
+ cls,
+ dataset: LeRobotDataset,
+ features: dict[str, PolicyFeature],
+ norm_map: dict[FeatureType, NormalizationMode],
+ *,
+ normalize_observation_keys: set[str] | None = None,
+ eps: float = 1e-8,
+ device: torch.device | str | None = None,
+ ) -> NormalizerProcessorStep:
+ """
+ Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`.
+
+ Args:
+ dataset: The dataset from which to extract normalization statistics.
+ features: The feature definition for the processor.
+ norm_map: The mapping from feature types to normalization modes.
+ normalize_observation_keys: An optional set of observation keys to normalize.
+ eps: A small epsilon value for numerical stability.
+ device: The target device for the processor.
+
+ Returns:
+ A new instance of `NormalizerProcessorStep`.
+ """
+ return cls(
+ features=features,
+ norm_map=norm_map,
+ stats=dataset.meta.stats,
+ normalize_observation_keys=normalize_observation_keys,
+ eps=eps,
+ device=device,
+ )
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ new_transition = transition.copy()
+
+ # Handle observation normalization.
+ observation = new_transition.get(TransitionKey.OBSERVATION)
+ if observation is not None:
+ new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
+ observation, inverse=False
+ )
+
+ # Handle action normalization.
+ action = new_transition.get(TransitionKey.ACTION)
+
+ if action is None:
+ return new_transition
+
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+
+ new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
+
+ return new_transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="unnormalizer_processor")
+class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
+ """
+ A processor step that applies unnormalization to observations and actions.
+
+ This class inverts the normalization process, scaling data back to its original
+ range. It is typically used in the post-processing pipeline to convert a policy's
+ normalized action output into a format that can be executed by a robot or
+ environment.
+ """
+
+ @classmethod
+ def from_lerobot_dataset(
+ cls,
+ dataset: LeRobotDataset,
+ features: dict[str, PolicyFeature],
+ norm_map: dict[FeatureType, NormalizationMode],
+ *,
+ device: torch.device | str | None = None,
+ ) -> UnnormalizerProcessorStep:
+ """
+ Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`.
+
+ Args:
+ dataset: The dataset from which to extract normalization statistics.
+ features: The feature definition for the processor.
+ norm_map: The mapping from feature types to normalization modes.
+ device: The target device for the processor.
+
+ Returns:
+ A new instance of `UnnormalizerProcessorStep`.
+ """
+ return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ new_transition = transition.copy()
+
+ # Handle observation unnormalization.
+ observation = new_transition.get(TransitionKey.OBSERVATION)
+ if observation is not None:
+ new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
+
+ # Handle action unnormalization.
+ action = new_transition.get(TransitionKey.ACTION)
+
+ if action is None:
+ return new_transition
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
+
+ new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
+
+ return new_transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+def hotswap_stats(
+ policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
+) -> PolicyProcessorPipeline:
+ """
+ Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance.
+
+ This function creates a deep copy of the provided pipeline and updates the
+ statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it
+ contains. This is useful for adapting a trained policy to a new environment or
+ dataset with different data distributions without having to reconstruct the entire
+ pipeline.
+
+ Args:
+ policy_processor: The policy processor pipeline to modify.
+ stats: The new dictionary of normalization statistics to apply.
+
+ Returns:
+ A new `PolicyProcessorPipeline` instance with the updated statistics.
+ """
+ rp = deepcopy(policy_processor)
+ for step in rp.steps:
+ if isinstance(step, _NormalizationMixin):
+ step.stats = stats
+ # Re-initialize tensor_stats on the correct device.
+ step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
+ return rp
diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py
new file mode 100644
index 000000000..d22d8fb96
--- /dev/null
+++ b/src/lerobot/processor/observation_processor.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+
+import einops
+import numpy as np
+import torch
+from torch import Tensor
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
+
+from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="observation_processor")
+class VanillaObservationProcessorStep(ObservationProcessorStep):
+ """
+ Processes standard Gymnasium observations into the LeRobot format.
+
+ This step handles both image and state data from a typical observation dictionary,
+ preparing it for use in a LeRobot policy.
+
+ **Image Processing:**
+ - Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W),
+ `float32` tensors.
+ - Normalizes pixel values from the [0, 255] range to [0, 1].
+ - Adds a batch dimension if one is not already present.
+ - Recognizes a single image under the key `"pixels"` and maps it to
+ `"observation.image"`.
+ - Recognizes a dictionary of images under the key `"pixels"` and maps them
+ to `"observation.images.{camera_name}"`.
+
+ **State Processing:**
+ - Maps the `"environment_state"` key to `"observation.environment_state"`.
+ - Maps the `"agent_pos"` key to `"observation.state"`.
+ - Converts NumPy arrays to PyTorch tensors.
+ - Adds a batch dimension if one is not already present.
+ """
+
+ def _process_single_image(self, img: np.ndarray) -> Tensor:
+ """
+ Processes a single NumPy image array into a channel-first, normalized tensor.
+
+ Args:
+ img: A NumPy array representing the image, expected to be in channel-last
+ (H, W, C) format with a `uint8` dtype.
+
+ Returns:
+ A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with
+ pixel values normalized to the [0, 1] range.
+
+ Raises:
+ ValueError: If the input image does not appear to be in channel-last
+ format or is not of `uint8` dtype.
+ """
+ # Convert to tensor
+ img_tensor = torch.from_numpy(img)
+
+ # Add batch dimension if needed
+ if img_tensor.ndim == 3:
+ img_tensor = img_tensor.unsqueeze(0)
+
+ # Validate image format
+ _, h, w, c = img_tensor.shape
+ if not (c < h and c < w):
+ raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
+
+ if img_tensor.dtype != torch.uint8:
+ raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
+
+ # Convert to channel-first format
+ img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
+
+ # Convert to float32 and normalize to [0, 1]
+ img_tensor = img_tensor.type(torch.float32) / 255.0
+
+ return img_tensor
+
+ def _process_observation(self, observation):
+ """
+ Processes both image and state observations.
+ """
+
+ processed_obs = observation.copy()
+
+ if "pixels" in processed_obs:
+ pixels = processed_obs.pop("pixels")
+
+ if isinstance(pixels, dict):
+ imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
+ else:
+ imgs = {OBS_IMAGE: pixels}
+
+ for imgkey, img in imgs.items():
+ processed_obs[imgkey] = self._process_single_image(img)
+
+ if "environment_state" in processed_obs:
+ env_state_np = processed_obs.pop("environment_state")
+ env_state = torch.from_numpy(env_state_np).float()
+ if env_state.dim() == 1:
+ env_state = env_state.unsqueeze(0)
+ processed_obs[OBS_ENV_STATE] = env_state
+
+ if "agent_pos" in processed_obs:
+ agent_pos_np = processed_obs.pop("agent_pos")
+ agent_pos = torch.from_numpy(agent_pos_np).float()
+ if agent_pos.dim() == 1:
+ agent_pos = agent_pos.unsqueeze(0)
+ processed_obs[OBS_STATE] = agent_pos
+
+ return processed_obs
+
+ def observation(self, observation):
+ return self._process_observation(observation)
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Transforms feature keys from the Gym standard to the LeRobot standard.
+
+ This method standardizes the feature dictionary by renaming keys according
+ to LeRobot's conventions, ensuring that policies can be constructed correctly.
+ It handles various raw key formats, including those with an "observation." prefix.
+
+ **Renaming Rules:**
+ - `pixels` or `observation.pixels` -> `observation.image`
+ - `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}`
+ - `environment_state` or `observation.environment_state` -> `observation.environment_state`
+ - `agent_pos` or `observation.agent_pos` -> `observation.state`
+
+ Args:
+ features: The policy features dictionary with Gym-style keys.
+
+ Returns:
+ The policy features dictionary with standardized LeRobot keys.
+ """
+ # Build a new features mapping keyed by the same FeatureType buckets
+ # We assume callers already placed features in the correct FeatureType.
+ new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
+
+ exact_pairs = {
+ "pixels": OBS_IMAGE,
+ "environment_state": OBS_ENV_STATE,
+ "agent_pos": OBS_STATE,
+ }
+
+ prefix_pairs = {
+ "pixels.": f"{OBS_IMAGES}.",
+ }
+
+ # Iterate over all incoming feature buckets and normalize/move each entry
+ for src_ft, bucket in features.items():
+ for key, feat in list(bucket.items()):
+ handled = False
+
+ # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
+ for old_prefix, new_prefix in prefix_pairs.items():
+ prefixed_old = f"{OBS_STR}.{old_prefix}"
+ if key.startswith(prefixed_old):
+ suffix = key[len(prefixed_old) :]
+ new_key = f"{new_prefix}{suffix}"
+ new_features[src_ft][new_key] = feat
+ handled = True
+ break
+
+ if key.startswith(old_prefix):
+ suffix = key[len(old_prefix) :]
+ new_key = f"{new_prefix}{suffix}"
+ new_features[src_ft][new_key] = feat
+ handled = True
+ break
+
+ if handled:
+ continue
+
+ # Exact-name rules (pixels, environment_state, agent_pos)
+ for old, new in exact_pairs.items():
+ if key == old or key == f"{OBS_STR}.{old}":
+ new_key = new
+ new_features[src_ft][new_key] = feat
+ handled = True
+ break
+
+ if handled:
+ continue
+
+ # Default: keep key in the same source FeatureType bucket
+ new_features[src_ft][key] = feat
+
+ return new_features
diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py
new file mode 100644
index 000000000..e14d8b0b9
--- /dev/null
+++ b/src/lerobot/processor/pipeline.py
@@ -0,0 +1,1716 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module defines a generic, sequential data processing pipeline framework, primarily designed for
+transforming robotics data (observations, actions, rewards, etc.).
+
+The core components are:
+- ProcessorStep: An abstract base class for a single data transformation operation.
+- ProcessorStepRegistry: A mechanism to register and retrieve ProcessorStep classes by name.
+- DataProcessorPipeline: A class that chains multiple ProcessorStep instances together to form a complete
+ data processing workflow. It integrates with the Hugging Face Hub for easy sharing and versioning of
+ pipelines, including their configuration and state.
+- Specialized abstract ProcessorStep subclasses (e.g., ObservationProcessorStep, ActionProcessorStep)
+ to simplify the creation of steps that target specific parts of a data transition.
+"""
+
+from __future__ import annotations
+
+import importlib
+import json
+import os
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Sequence
+from copy import deepcopy
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file, save_file
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.utils.hub import HubMixin
+
+from .converters import batch_to_transition, create_transition, transition_to_batch
+from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
+
+# Generic type variables for pipeline input and output.
+TInput = TypeVar("TInput")
+TOutput = TypeVar("TOutput")
+
+
+class ProcessorStepRegistry:
+ """A registry for ProcessorStep classes to allow instantiation from a string name.
+
+ This class provides a way to map string identifiers to `ProcessorStep` classes,
+ which is useful for deserializing pipelines from configuration files without
+
+ hardcoding class imports.
+ """
+
+ _registry: dict[str, type] = {}
+
+ @classmethod
+ def register(cls, name: str | None = None):
+ """A class decorator to register a ProcessorStep.
+
+ Args:
+ name: The name to register the class under. If None, the class's `__name__` is used.
+
+ Returns:
+ A decorator function that registers the class and returns it.
+
+ Raises:
+ ValueError: If a step with the same name is already registered.
+ """
+
+ def decorator(step_class: type) -> type:
+ """The actual decorator that performs the registration."""
+ registration_name = name if name is not None else step_class.__name__
+
+ if registration_name in cls._registry:
+ raise ValueError(
+ f"Processor step '{registration_name}' is already registered. "
+ f"Use a different name or unregister the existing one first."
+ )
+
+ cls._registry[registration_name] = step_class
+ # Store the registration name on the class for easy lookup during serialization.
+ step_class._registry_name = registration_name
+ return step_class
+
+ return decorator
+
+ @classmethod
+ def get(cls, name: str) -> type:
+ """Retrieves a processor step class from the registry by its name.
+
+ Args:
+ name: The name of the step to retrieve.
+
+ Returns:
+ The processor step class corresponding to the given name.
+
+ Raises:
+ KeyError: If the name is not found in the registry.
+ """
+ if name not in cls._registry:
+ available = list(cls._registry.keys())
+ raise KeyError(
+ f"Processor step '{name}' not found in registry. "
+ f"Available steps: {available}. "
+ f"Make sure the step is registered using @ProcessorStepRegistry.register()"
+ )
+ return cls._registry[name]
+
+ @classmethod
+ def unregister(cls, name: str) -> None:
+ """Removes a processor step from the registry.
+
+ Args:
+ name: The name of the step to unregister.
+ """
+ cls._registry.pop(name, None)
+
+ @classmethod
+ def list(cls) -> list[str]:
+ """Returns a list of all registered processor step names."""
+ return list(cls._registry.keys())
+
+ @classmethod
+ def clear(cls) -> None:
+ """Clears all processor steps from the registry."""
+ cls._registry.clear()
+
+
+class ProcessorStep(ABC):
+ """Abstract base class for a single step in a data processing pipeline.
+
+ Each step must implement the `__call__` method to perform its transformation
+ on a data transition and the `transform_features` method to describe how it
+ alters the shape or type of data features.
+
+ Subclasses can optionally be stateful by implementing `state_dict` and `load_state_dict`.
+ """
+
+ _current_transition: EnvTransition | None = None
+
+ @property
+ def transition(self) -> EnvTransition:
+ """Provides access to the most recent transition being processed.
+
+ This is useful for steps that need to access other parts of the transition
+ data beyond their primary target (e.g., an action processing step that
+ needs to look at the observation).
+
+ Raises:
+ ValueError: If accessed before the step has been called with a transition.
+ """
+ if self._current_transition is None:
+ raise ValueError("Transition is not set. Make sure to call the step with a transition first.")
+ return self._current_transition
+
+ @abstractmethod
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Processes an environment transition.
+
+ This method should contain the core logic of the processing step.
+
+ Args:
+ transition: The input data transition to be processed.
+
+ Returns:
+ The processed transition.
+ """
+ return transition
+
+ def get_config(self) -> dict[str, Any]:
+ """Returns the configuration of the step for serialization.
+
+ Returns:
+ A JSON-serializable dictionary of configuration parameters.
+ """
+ return {}
+
+ def state_dict(self) -> dict[str, torch.Tensor]:
+ """Returns the state of the step (e.g., learned parameters, running means).
+
+ Returns:
+ A dictionary mapping state names to tensors.
+ """
+ return {}
+
+ def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
+ """Loads the step's state from a state dictionary.
+
+ Args:
+ state: A dictionary of state tensors.
+ """
+ return None
+
+ def reset(self) -> None:
+ """Resets the internal state of the processor step, if any."""
+ return None
+
+ @abstractmethod
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Defines how this step modifies the description of pipeline features.
+
+ This method is used to track changes in data shapes, dtypes, or modalities
+ as data flows through the pipeline, without needing to process actual data.
+
+ Args:
+ features: A dictionary describing the input features for observations, actions, etc.
+
+ Returns:
+ A dictionary describing the output features after this step's transformation.
+ """
+ return features
+
+
+class ProcessorKwargs(TypedDict, total=False):
+ """A TypedDict for optional keyword arguments used in pipeline construction."""
+
+ to_transition: Callable[[dict[str, Any]], EnvTransition] | None
+ to_output: Callable[[EnvTransition], Any] | None
+ name: str | None
+ before_step_hooks: list[Callable[[int, EnvTransition], None]] | None
+ after_step_hooks: list[Callable[[int, EnvTransition], None]] | None
+
+
+class ProcessorMigrationError(Exception):
+ """Raised when a model needs migration to the processor format"""
+
+ def __init__(self, model_path: str | Path, migration_command: str, original_error: str):
+ self.model_path = model_path
+ self.migration_command = migration_command
+ self.original_error = original_error
+ super().__init__(
+ f"Model '{model_path}' requires migration to processor format. "
+ f"Run: {migration_command}\n\nOriginal error: {original_error}"
+ )
+
+
+@dataclass
+class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
+ """A sequential pipeline for processing data, integrated with the Hugging Face Hub.
+
+ This class chains together multiple `ProcessorStep` instances to form a complete
+ data processing workflow. It's generic, allowing for custom input and output types,
+ which are handled by the `to_transition` and `to_output` converters.
+
+ Attributes:
+ steps: A sequence of `ProcessorStep` objects that make up the pipeline.
+ name: A descriptive name for the pipeline.
+ to_transition: A function to convert raw input data into the standardized `EnvTransition` format.
+ to_output: A function to convert the final `EnvTransition` into the desired output format.
+ before_step_hooks: A list of functions to be called before each step is executed.
+ after_step_hooks: A list of functions to be called after each step is executed.
+ """
+
+ steps: Sequence[ProcessorStep] = field(default_factory=list)
+ name: str = "DataProcessorPipeline"
+
+ to_transition: Callable[[TInput], EnvTransition] = field(
+ default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False
+ )
+ to_output: Callable[[EnvTransition], TOutput] = field(
+ default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch),
+ repr=False,
+ )
+
+ before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
+ after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
+
+ def __call__(self, data: TInput) -> TOutput:
+ """Processes input data through the full pipeline.
+
+ Args:
+ data: The input data to process.
+
+ Returns:
+ The processed data in the specified output format.
+ """
+ transition = self.to_transition(data)
+ transformed_transition = self._forward(transition)
+ return self.to_output(transformed_transition)
+
+ def _forward(self, transition: EnvTransition) -> EnvTransition:
+ """Executes all processing steps and hooks in sequence.
+
+ Args:
+ transition: The initial `EnvTransition` object.
+
+ Returns:
+ The final `EnvTransition` after all steps have been applied.
+ """
+ for idx, processor_step in enumerate(self.steps):
+ # Execute pre-hooks
+ for hook in self.before_step_hooks:
+ hook(idx, transition)
+
+ transition = processor_step(transition)
+
+ # Execute post-hooks
+ for hook in self.after_step_hooks:
+ hook(idx, transition)
+ return transition
+
+ def step_through(self, data: TInput) -> Iterable[EnvTransition]:
+ """Processes data step-by-step, yielding the transition at each stage.
+
+ This is a generator method useful for debugging and inspecting the intermediate
+ state of the data as it passes through the pipeline.
+
+ Args:
+ data: The input data.
+
+ Yields:
+ The `EnvTransition` object, starting with the initial state and then after
+ each processing step.
+ """
+ transition = self.to_transition(data)
+
+ # Yield the initial state before any processing.
+ yield transition
+
+ for processor_step in self.steps:
+ transition = processor_step(transition)
+ yield transition
+
+ def _save_pretrained(self, save_directory: Path, **kwargs):
+ """Internal method to comply with `HubMixin`'s saving mechanism.
+
+ This method does the actual saving work and is called by HubMixin.save_pretrained.
+ """
+ config_filename = kwargs.pop("config_filename", None)
+
+ # Sanitize the pipeline name to create a valid filename prefix.
+ sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
+
+ if config_filename is None:
+ config_filename = f"{sanitized_name}.json"
+
+ config: dict[str, Any] = {
+ "name": self.name,
+ "steps": [],
+ }
+
+ # Iterate through each step to build its configuration entry.
+ for step_index, processor_step in enumerate(self.steps):
+ registry_name = getattr(processor_step.__class__, "_registry_name", None)
+
+ step_entry: dict[str, Any] = {}
+ # Prefer registry name for portability, otherwise fall back to full class path.
+ if registry_name:
+ step_entry["registry_name"] = registry_name
+ else:
+ step_entry["class"] = (
+ f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
+ )
+
+ # Save step configuration if `get_config` is implemented.
+ if hasattr(processor_step, "get_config"):
+ step_entry["config"] = processor_step.get_config()
+
+ # Save step state if `state_dict` is implemented and returns a non-empty dict.
+ if hasattr(processor_step, "state_dict"):
+ state = processor_step.state_dict()
+ if state:
+ # Clone tensors to avoid modifying the original state.
+ cloned_state = {key: tensor.clone() for key, tensor in state.items()}
+
+ # Create a unique filename for the state file.
+ if registry_name:
+ state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
+ else:
+ state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
+
+ save_file(cloned_state, os.path.join(str(save_directory), state_filename))
+ step_entry["state_file"] = state_filename
+
+ config["steps"].append(step_entry)
+
+ # Write the main configuration JSON file.
+ with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
+ json.dump(config, file_pointer, indent=2)
+
+ def save_pretrained(
+ self,
+ save_directory: str | Path | None = None,
+ *,
+ repo_id: str | None = None,
+ push_to_hub: bool = False,
+ card_kwargs: dict[str, Any] | None = None,
+ config_filename: str | None = None,
+ **push_to_hub_kwargs,
+ ):
+ """Saves the pipeline's configuration and state to a directory.
+
+ This method creates a JSON configuration file that defines the pipeline's structure
+ (name and steps). For each stateful step, it also saves a `.safetensors` file
+ containing its state dictionary.
+
+ Args:
+ save_directory: The directory where the pipeline will be saved. If None, saves to
+ HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
+ repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
+ push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
+ card_kwargs: Additional arguments passed to the card template to customize the card.
+ config_filename: The name of the JSON configuration file. If None, a name is
+ generated from the pipeline's `name` attribute.
+ **push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method.
+ """
+ if save_directory is None:
+ # Use default directory in HF_LEROBOT_HOME
+ from lerobot.utils.constants import HF_LEROBOT_HOME
+
+ sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
+ save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
+
+ # For direct saves (not through hub), handle config_filename
+ if not push_to_hub and config_filename is not None:
+ # Call _save_pretrained directly with config_filename
+ save_directory = Path(save_directory)
+ save_directory.mkdir(parents=True, exist_ok=True)
+ self._save_pretrained(save_directory, config_filename=config_filename)
+ return None
+
+ # Pass config_filename through kwargs for _save_pretrained when using hub
+ if config_filename is not None:
+ push_to_hub_kwargs["config_filename"] = config_filename
+
+ # Call parent's save_pretrained which will call our _save_pretrained
+ return super().save_pretrained(
+ save_directory=save_directory,
+ repo_id=repo_id,
+ push_to_hub=push_to_hub,
+ card_kwargs=card_kwargs,
+ **push_to_hub_kwargs,
+ )
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str | Path,
+ config_filename: str,
+ *,
+ force_download: bool = False,
+ resume_download: bool | None = None,
+ proxies: dict[str, str] | None = None,
+ token: str | bool | None = None,
+ cache_dir: str | Path | None = None,
+ local_files_only: bool = False,
+ revision: str | None = None,
+ overrides: dict[str, Any] | None = None,
+ to_transition: Callable[[TInput], EnvTransition] | None = None,
+ to_output: Callable[[EnvTransition], TOutput] | None = None,
+ **kwargs,
+ ) -> DataProcessorPipeline[TInput, TOutput]:
+ """Loads a pipeline from a local directory, single file, or Hugging Face Hub repository.
+
+ This method implements a simplified loading pipeline with intelligent migration detection:
+
+ **Simplified Loading Strategy**:
+ 1. **Config Loading** (_load_config):
+ - **Directory**: Load specified config_filename from directory
+ - **Single file**: Load file directly (config_filename ignored)
+ - **Hub repository**: Download specified config_filename from Hub
+
+ 2. **Config Validation** (_validate_loaded_config):
+ - Format validation: Ensure config is valid processor format
+ - Migration detection: Guide users to migrate old LeRobot models
+ - Clear errors: Provide actionable error messages
+
+ 3. **Step Construction** (_build_steps_with_overrides):
+ - Class resolution: Registry lookup or dynamic imports
+ - Override merging: User parameters override saved config
+ - State loading: Load .safetensors files for stateful steps
+
+ 4. **Override Validation** (_validate_overrides_used):
+ - Ensure all user overrides were applied (catch typos)
+ - Provide helpful error messages with available keys
+
+ **Migration Detection**:
+ - **Smart detection**: Analyzes JSON files to detect old LeRobot models
+ - **Precise targeting**: Avoids false positives on other HuggingFace models
+ - **Clear guidance**: Provides exact migration command to run
+ - **Error mode**: Always raises ProcessorMigrationError for clear user action
+
+ **Loading Examples**:
+ ```python
+ # Directory loading
+ pipeline = DataProcessorPipeline.from_pretrained("/models/my_model", config_filename="processor.json")
+
+ # Single file loading
+ pipeline = DataProcessorPipeline.from_pretrained(
+ "/models/my_model/processor.json", config_filename="processor.json"
+ )
+
+ # Hub loading
+ pipeline = DataProcessorPipeline.from_pretrained("user/repo", config_filename="processor.json")
+
+ # Multiple configs (preprocessor/postprocessor)
+ preprocessor = DataProcessorPipeline.from_pretrained(
+ "model", config_filename="policy_preprocessor.json"
+ )
+ postprocessor = DataProcessorPipeline.from_pretrained(
+ "model", config_filename="policy_postprocessor.json"
+ )
+ ```
+
+ **Override System**:
+ - **Key matching**: Use registry names or class names as override keys
+ - **Config merging**: User overrides take precedence over saved config
+ - **Validation**: Ensure all override keys match actual steps (catch typos)
+ - **Example**: overrides={"NormalizeStep": {"device": "cuda"}}
+
+ Args:
+ pretrained_model_name_or_path: The identifier of the repository on the Hugging Face Hub,
+ a path to a local directory, or a path to a single config file.
+ config_filename: The name of the pipeline's JSON configuration file. Always required
+ to prevent ambiguity when multiple configs exist (e.g., preprocessor vs postprocessor).
+ force_download: Whether to force (re)downloading the files.
+ resume_download: Whether to resume a previously interrupted download.
+ proxies: A dictionary of proxy servers to use.
+ token: The token to use as HTTP bearer authorization for private Hub repositories.
+ cache_dir: The path to a specific cache folder to store downloaded files.
+ local_files_only: If True, avoid downloading files from the Hub.
+ revision: The specific model version to use (e.g., a branch name, tag name, or commit id).
+ overrides: A dictionary to override the configuration of specific steps. Keys should
+ match the step's class name or registry name.
+ to_transition: A custom function to convert input data to `EnvTransition`.
+ to_output: A custom function to convert the final `EnvTransition` to the output format.
+ **kwargs: Additional arguments (not used).
+
+ Returns:
+ An instance of `DataProcessorPipeline` loaded with the specified configuration and state.
+
+ Raises:
+ FileNotFoundError: If the config file cannot be found.
+ ValueError: If configuration is ambiguous or instantiation fails.
+ ImportError: If a step's class cannot be imported.
+ KeyError: If an override key doesn't match any step in the pipeline.
+ ProcessorMigrationError: If the model requires migration to processor format.
+ """
+ model_id = str(pretrained_model_name_or_path)
+ hub_download_kwargs = {
+ "force_download": force_download,
+ "resume_download": resume_download,
+ "proxies": proxies,
+ "token": token,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ "revision": revision,
+ }
+
+ # 1. Load configuration using simplified 3-way logic
+ loaded_config, base_path = cls._load_config(model_id, config_filename, hub_download_kwargs)
+
+ # 2. Validate configuration and handle migration
+ cls._validate_loaded_config(model_id, loaded_config, config_filename)
+
+ # 3. Build steps with overrides
+ steps, validated_overrides = cls._build_steps_with_overrides(
+ loaded_config, overrides or {}, model_id, base_path, hub_download_kwargs
+ )
+
+ # 4. Validate that all overrides were used
+ cls._validate_overrides_used(validated_overrides, loaded_config)
+
+ # 5. Construct and return the final pipeline instance
+ return cls(
+ steps=steps,
+ name=loaded_config.get("name", "DataProcessorPipeline"),
+ to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
+ to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
+ )
+
+ @classmethod
+ def _load_config(
+ cls,
+ model_id: str,
+ config_filename: str,
+ hub_download_kwargs: dict[str, Any],
+ ) -> tuple[dict[str, Any], Path]:
+ """Load configuration from local file or Hugging Face Hub.
+
+ This method implements a super-simplified 3-way loading strategy:
+
+ 1. **Local directory**: Load config_filename from directory
+ - Example: model_id="/models/my_model", config_filename="processor.json"
+ - Loads: "/models/my_model/processor.json"
+
+ 2. **Single file**: Load file directly (ignore config_filename)
+ - Example: model_id="/models/my_model/processor.json"
+ - Loads: "/models/my_model/processor.json" (config_filename ignored)
+
+ 3. **Hub repository**: Download config_filename from Hub
+ - Example: model_id="user/repo", config_filename="processor.json"
+ - Downloads and loads: config_filename from Hub repo
+
+ **Benefits of Explicit config_filename**:
+ - No auto-detection complexity or edge cases
+ - No risk of loading wrong config (preprocessor vs postprocessor)
+ - Consistent behavior across local and Hub usage
+ - Clear, predictable errors
+
+ Args:
+ model_id: The model identifier (Hub repo ID, local directory, or file path)
+ config_filename: The explicit config filename to load (always required)
+ hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.)
+
+ Returns:
+ Tuple of (loaded_config, base_path)
+ - loaded_config: Parsed JSON config dict (always loaded, never None)
+ - base_path: Directory containing config file (for state file resolution)
+
+ Raises:
+ FileNotFoundError: If config file cannot be found locally or on Hub
+ """
+ model_path = Path(model_id)
+
+ if model_path.is_dir():
+ # Directory: load specified config from directory
+ config_path = model_path / config_filename
+ if not config_path.exists():
+ # Check for migration before giving clear error
+ if cls._should_suggest_migration(model_path):
+ cls._suggest_processor_migration(model_id, f"Config file '{config_filename}' not found")
+ raise FileNotFoundError(
+ f"Config file '{config_filename}' not found in directory '{model_id}'"
+ )
+
+ with open(config_path) as f:
+ return json.load(f), model_path
+
+ elif model_path.is_file():
+ # File: load file directly (config_filename is ignored for single files)
+ with open(model_path) as f:
+ return json.load(f), model_path.parent
+
+ else:
+ # Hub: download specified config
+ try:
+ config_path = hf_hub_download(
+ repo_id=model_id,
+ filename=config_filename,
+ repo_type="model",
+ **hub_download_kwargs,
+ )
+
+ with open(config_path) as f:
+ return json.load(f), Path(config_path).parent
+
+ except Exception as e:
+ raise FileNotFoundError(
+ f"Could not find '{config_filename}' on the HuggingFace Hub at '{model_id}'"
+ ) from e
+
+ @classmethod
+ def _validate_loaded_config(
+ cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
+ ) -> None:
+ """Validate that a config was loaded and is a valid processor config.
+
+ This method validates processor config format with intelligent migration detection:
+
+ **Config Format Validation**:
+ - Use _is_processor_config() to validate structure
+ - Must have "steps" field with list of step configurations
+ - Each step needs "class" or "registry_name"
+ - If validation fails AND local directory: Check for migration need
+ - If migration needed: Raise ProcessorMigrationError with command
+ - If no migration: Raise ValueError with helpful error message
+
+ **Migration Detection Logic**:
+ - Only triggered for local directories (not Hub repos)
+ - Analyzes all JSON files in directory to detect old LeRobot models
+ - Provides exact migration command with model path
+
+ Args:
+ model_id: The model identifier (used for migration detection)
+ loaded_config: The loaded config dictionary (guaranteed non-None)
+ config_filename: The config filename that was loaded (for error messages)
+
+ Raises:
+ ValueError: If config format is invalid
+ ProcessorMigrationError: If model needs migration to processor format
+ """
+ # Validate that this is actually a processor config
+ if not cls._is_processor_config(loaded_config):
+ if Path(model_id).is_dir() and cls._should_suggest_migration(Path(model_id)):
+ cls._suggest_processor_migration(
+ model_id,
+ f"Config file '{config_filename}' is not a valid processor configuration",
+ )
+ raise ValueError(
+ f"Config file '{config_filename}' is not a valid processor configuration. "
+ f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
+ )
+
+ @classmethod
+ def _build_steps_with_overrides(
+ cls,
+ loaded_config: dict[str, Any],
+ overrides: dict[str, Any],
+ model_id: str,
+ base_path: Path | None,
+ hub_download_kwargs: dict[str, Any],
+ ) -> tuple[list[ProcessorStep], set[str]]:
+ """Build all processor steps with overrides and state loading.
+
+ This method orchestrates the complete step construction pipeline:
+
+ **For each step in loaded_config["steps"]**:
+
+ 1. **Class Resolution** (via _resolve_step_class):
+ - **If "registry_name" exists**: Look up in ProcessorStepRegistry
+ Example: {"registry_name": "normalize_step"} -> Get registered class
+ - **Else use "class" field**: Dynamic import from full module path
+ Example: {"class": "lerobot.processor.normalize.NormalizeStep"}
+ - **Result**: (step_class, step_key) where step_key is used for overrides
+
+ 2. **Step Instantiation** (via _instantiate_step):
+ - **Merge configs**: saved_config + user_overrides
+ - **Override priority**: User overrides take precedence over saved config
+ - **Example**: saved={"mean": 0.0}, override={"mean": 1.0} -> final={"mean": 1.0}
+ - **Result**: Instantiated ProcessorStep object
+
+ 3. **State Loading** (via _load_step_state):
+ - **If step has "state_file"**: Load tensor state from .safetensors
+ - **Local first**: Check base_path/state_file.safetensors
+ - **Hub fallback**: Download state file if not found locally
+ - **Optional**: Only load if step has load_state_dict method
+
+ 4. **Override Tracking**:
+ - **Track used overrides**: Remove step_key from remaining set
+ - **Purpose**: Validate all user overrides were applied (detect typos)
+
+ **Error Handling**:
+ - Class resolution errors -> ImportError with helpful message
+ - Instantiation errors -> ValueError with config details
+ - State loading errors -> Propagated from load_state_dict
+
+ Args:
+ loaded_config: The loaded processor configuration (must have "steps" field)
+ overrides: User-provided parameter overrides (keyed by class/registry name)
+ model_id: The model identifier (needed for Hub state file downloads)
+ base_path: Local directory path for finding state files
+ hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.)
+
+ Returns:
+ Tuple of (instantiated_steps_list, unused_override_keys)
+ - instantiated_steps_list: List of ready-to-use ProcessorStep instances
+ - unused_override_keys: Override keys that didn't match any step (for validation)
+
+ Raises:
+ ImportError: If a step class cannot be imported or found in registry
+ ValueError: If a step cannot be instantiated with its configuration
+ """
+ steps: list[ProcessorStep] = []
+ override_keys = set(overrides.keys())
+
+ for step_entry in loaded_config["steps"]:
+ # 1. Get step class and key
+ step_class, step_key = cls._resolve_step_class(step_entry)
+
+ # 2. Instantiate step with overrides
+ step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
+
+ # 3. Load step state if available
+ cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
+
+ # 4. Track used overrides
+ if step_key in override_keys:
+ override_keys.discard(step_key)
+
+ steps.append(step_instance)
+
+ return steps, override_keys
+
+ @classmethod
+ def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
+ """Resolve step class from registry or import path.
+
+ This method implements a two-tier resolution strategy:
+
+ **Tier 1: Registry-based resolution** (preferred):
+ - **If "registry_name" in step_entry**: Look up in ProcessorStepRegistry
+ - **Advantage**: Faster, no imports needed, guaranteed compatibility
+ - **Example**: {"registry_name": "normalize_step"} -> Get pre-registered class
+ - **Error**: KeyError if registry_name not found -> Convert to ImportError
+
+ **Tier 2: Dynamic import fallback**:
+ - **Else use "class" field**: Full module.ClassName import path
+ - **Process**: Split "module.path.ClassName" into module + class parts
+ - **Import**: Use importlib.import_module() + getattr()
+ - **Example**: "lerobot.processor.normalize.NormalizeStep"
+ a. Import module: "lerobot.processor.normalize"
+ b. Get class: getattr(module, "NormalizeStep")
+ - **step_key**: Use class_name ("NormalizeStep") for overrides
+
+ **Override Key Strategy**:
+ - Registry steps: Use registry_name ("normalize_step")
+ - Import steps: Use class_name ("NormalizeStep")
+ - This allows users to override with: {"normalize_step": {...}} or {"NormalizeStep": {...}}
+
+ **Error Handling**:
+ - Registry KeyError -> ImportError with registry context
+ - Import/Attribute errors -> ImportError with helpful suggestions
+ - All errors include troubleshooting guidance
+
+ Args:
+ step_entry: The step configuration dictionary (must have "registry_name" or "class")
+
+ Returns:
+ Tuple of (step_class, step_key)
+ - step_class: The resolved ProcessorStep class (ready for instantiation)
+ - step_key: The key used for user overrides (registry_name or class_name)
+
+ Raises:
+ ImportError: If step class cannot be loaded from registry or import path
+ """
+ if "registry_name" in step_entry:
+ try:
+ step_class = ProcessorStepRegistry.get(step_entry["registry_name"])
+ return step_class, step_entry["registry_name"]
+ except KeyError as e:
+ raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e
+ else:
+ # Fallback to dynamic import using the full class path
+ full_class_path = step_entry["class"]
+ module_path, class_name = full_class_path.rsplit(".", 1)
+
+ try:
+ module = importlib.import_module(module_path)
+ step_class = getattr(module, class_name)
+ return step_class, class_name
+ except (ImportError, AttributeError) as e:
+ raise ImportError(
+ f"Failed to load processor step '{full_class_path}'. "
+ f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. "
+ f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. "
+ f"Error: {str(e)}"
+ ) from e
+
+ @classmethod
+ def _instantiate_step(
+ cls,
+ step_entry: dict[str, Any],
+ step_class: type[ProcessorStep],
+ step_key: str,
+ overrides: dict[str, Any],
+ ) -> ProcessorStep:
+ """Instantiate a single processor step with config overrides.
+
+ This method handles the configuration merging and instantiation logic:
+
+ **Configuration Merging Strategy**:
+ 1. **Extract saved config**: Get step_entry.get("config", {}) from saved pipeline
+ - Example: {"config": {"mean": 0.0, "std": 1.0}}
+ 2. **Extract user overrides**: Get overrides.get(step_key, {}) for this step
+ - Example: overrides = {"NormalizeStep": {"mean": 2.0, "device": "cuda"}}
+ 3. **Merge with priority**: {**saved_cfg, **step_overrides}
+ - **Override priority**: User values override saved values
+ - **Result**: {"mean": 2.0, "std": 1.0, "device": "cuda"}
+
+ **Instantiation Process**:
+ - **Call constructor**: step_class(**merged_cfg)
+ - **Example**: NormalizeStep(mean=2.0, std=1.0, device="cuda")
+
+ **Error Handling**:
+ - **Any exception during instantiation**: Convert to ValueError
+ - **Include context**: step name, attempted config, original error
+ - **Purpose**: Help users debug configuration issues
+ - **Common causes**:
+ a. Invalid parameter types (str instead of float)
+ b. Missing required parameters
+ c. Incompatible parameter combinations
+
+ Args:
+ step_entry: The step configuration from saved config (contains "config" dict)
+ step_class: The step class to instantiate (already resolved)
+ step_key: The key used for overrides ("registry_name" or class name)
+ overrides: User-provided parameter overrides (keyed by step_key)
+
+ Returns:
+ The instantiated processor step (ready for use)
+
+ Raises:
+ ValueError: If step cannot be instantiated, with detailed error context
+ """
+ try:
+ saved_cfg = step_entry.get("config", {})
+ step_overrides = overrides.get(step_key, {})
+ merged_cfg = {**saved_cfg, **step_overrides}
+ return step_class(**merged_cfg)
+ except Exception as e:
+ step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown"))
+ raise ValueError(
+ f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. "
+ f"Error: {str(e)}"
+ ) from e
+
+ @classmethod
+ def _load_step_state(
+ cls,
+ step_instance: ProcessorStep,
+ step_entry: dict[str, Any],
+ model_id: str,
+ base_path: Path | None,
+ hub_download_kwargs: dict[str, Any],
+ ) -> None:
+ """Load state dictionary for a processor step if available.
+
+ This method implements conditional state loading with local/Hub fallback:
+
+ **Precondition Checks** (early return if not met):
+ 1. **"state_file" in step_entry**: Step config specifies a state file
+ - **If missing**: Step has no saved state (e.g., stateless transforms)
+ 2. **hasattr(step_instance, "load_state_dict")**: Step supports state loading
+ - **If missing**: Step doesn't implement state loading (rare)
+
+ **State File Resolution Strategy**:
+ 1. **Local file priority**: Check base_path/state_filename exists
+ - **Advantage**: Faster, no network calls
+ - **Example**: "/models/my_model/normalize_step_0.safetensors"
+ - **Use case**: Loading from local saved model directory
+
+ 2. **Hub download fallback**: Download state file from repository
+ - **When triggered**: Local file not found or base_path is None
+ - **Process**: Use hf_hub_download with same parameters as config
+ - **Example**: Download "normalize_step_0.safetensors" from "user/repo"
+ - **Result**: Downloaded to local cache, path returned
+
+ **State Loading Process**:
+ - **Load tensors**: Use safetensors.torch.load_file()
+ - **Apply to step**: Call step_instance.load_state_dict(tensor_dict)
+ - **In-place modification**: Updates step's internal tensor state
+
+ **Common state file examples**:
+ - "normalize_step_0.safetensors" - normalization statistics
+ - "custom_step_1.safetensors" - learned parameters
+ - "tokenizer_step_2.safetensors" - vocabulary embeddings
+
+ Args:
+ step_instance: The step instance to load state into (must have load_state_dict)
+ step_entry: The step configuration dictionary (may contain "state_file")
+ model_id: The model identifier (used for Hub downloads if needed)
+ base_path: Local directory path for finding state files (None for Hub-only)
+ hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.)
+
+ Note:
+ This method modifies step_instance in-place and returns None.
+ If state loading fails, exceptions from load_state_dict propagate.
+ """
+ if "state_file" not in step_entry or not hasattr(step_instance, "load_state_dict"):
+ return
+
+ state_filename = step_entry["state_file"]
+
+ # Try local file first
+ if base_path and (base_path / state_filename).exists():
+ state_path = str(base_path / state_filename)
+ else:
+ # Download from Hub
+ state_path = hf_hub_download(
+ repo_id=model_id,
+ filename=state_filename,
+ repo_type="model",
+ **hub_download_kwargs,
+ )
+
+ step_instance.load_state_dict(load_file(state_path))
+
+ @classmethod
+ def _validate_overrides_used(
+ cls, remaining_override_keys: set[str], loaded_config: dict[str, Any]
+ ) -> None:
+ """Validate that all provided overrides were used.
+
+ This method ensures user overrides are valid to catch typos and configuration errors:
+
+ **Validation Logic**:
+ 1. **If remaining_override_keys is empty**: All overrides were used -> Success
+ - **Early return**: No validation needed
+ - **Normal case**: User provided correct override keys
+
+ 2. **If remaining_override_keys has entries**: Some overrides unused -> Error
+ - **Root cause**: User provided keys that don't match any step
+ - **Common issues**:
+ a. Typos in step names ("NormalizStep" vs "NormalizeStep")
+ b. Using wrong key type (class name vs registry name)
+ c. Step doesn't exist in saved pipeline
+
+ **Helpful Error Generation**:
+ - **Extract available keys**: Build list of valid override keys from config
+ a. **Registry steps**: Use "registry_name" directly
+ b. **Import steps**: Extract class name from "class" field
+ - Example: "lerobot.processor.normalize.NormalizeStep" -> "NormalizeStep"
+ - **Error message includes**:
+ a. Invalid keys provided by user
+ b. List of valid keys they can use
+ c. Guidance about registry vs class names
+
+ **Override Key Resolution Rules**:
+ - Steps with "registry_name": Use registry_name for overrides
+ - Steps with "class": Use final class name for overrides
+ - Users must match these exact keys in their overrides dict
+
+ Args:
+ remaining_override_keys: Override keys that weren't matched to any step
+ loaded_config: The loaded processor configuration (contains "steps" list)
+
+ Raises:
+ KeyError: If any override keys were not used, with helpful error message
+ """
+ if not remaining_override_keys:
+ return
+
+ available_keys = [
+ step.get("registry_name") or step["class"].rsplit(".", 1)[1] for step in loaded_config["steps"]
+ ]
+
+ raise KeyError(
+ f"Override keys {list(remaining_override_keys)} do not match any step in the saved configuration. "
+ f"Available step keys: {available_keys}. "
+ f"Make sure override keys match exact step class names or registry names."
+ )
+
+ @classmethod
+ def _should_suggest_migration(cls, model_path: Path) -> bool:
+ """Check if directory has JSON files but no processor configs.
+
+ This method implements smart migration detection to avoid false positives:
+
+ **Decision Logic**:
+ 1. **No JSON files found**: Return False
+ - **Reason**: Empty directory or only non-config files
+ - **Example**: Directory with only .safetensors, .md files
+ - **Action**: No migration needed
+
+ 2. **JSON files exist**: Analyze each file
+ - **Goal**: Determine if ANY file is a valid processor config
+ - **Process**:
+ a. Try to parse each .json file
+ b. Skip files with JSON parse errors (malformed)
+ c. Check if parsed config passes _is_processor_config()
+ - **If ANY valid processor found**: Return False (no migration)
+ - **If NO valid processors found**: Return True (migration needed)
+
+ **Examples**:
+ - **No migration**: ["processor.json", "config.json"] where processor.json is valid
+ - **Migration needed**: ["config.json", "train.json"] where both are model configs
+ - **No migration**: [] (empty directory)
+ - **Migration needed**: ["old_model_config.json"] with old LeRobot format
+
+ **Why this works**:
+ - **Precise detection**: Only suggests migration for actual old LeRobot models
+ - **Avoids false positives**: Won't trigger on other HuggingFace model types
+ - **Graceful handling**: Ignores malformed JSON files
+
+ Args:
+ model_path: Path to local directory to analyze
+
+ Returns:
+ True if directory has JSON configs but none are processor configs (migration needed)
+ False if no JSON files or at least one valid processor config exists
+ """
+ json_files = list(model_path.glob("*.json"))
+ if len(json_files) == 0:
+ return False
+
+ # Check if any JSON file is a processor config
+ for json_file in json_files:
+ try:
+ with open(json_file) as f:
+ config = json.load(f)
+
+ if cls._is_processor_config(config):
+ return False # Found at least one processor config, no migration needed
+
+ except (json.JSONDecodeError, OSError):
+ # Skip files that can't be parsed as JSON
+ continue
+
+ # Have JSON files but no processor configs - suggest migration
+ return True
+
+ @classmethod
+ def _is_processor_config(cls, config: dict) -> bool:
+ """Check if config follows DataProcessorPipeline format.
+
+ This method validates the processor configuration structure:
+
+ **Required Structure Validation**:
+ 1. **"steps" field existence**: Must have top-level "steps" key
+ - **If missing**: Not a processor config (e.g., model config, train config)
+ - **Example invalid**: {"type": "act", "hidden_dim": 256}
+
+ 2. **"steps" field type**: Must be a list, not other types
+ - **If not list**: Invalid format
+ - **Example invalid**: {"steps": "some_string"} or {"steps": {"key": "value"}}
+
+ 3. **Empty steps validation**: Empty list is valid
+ - **If len(steps) == 0**: Return True immediately
+ - **Use case**: Empty processor pipeline (no-op)
+ - **Example valid**: {"name": "EmptyProcessor", "steps": []}
+
+ **Individual Step Validation** (for non-empty steps):
+ For each step in the steps list:
+ 1. **Step type**: Must be a dictionary
+ - **If not dict**: Invalid step format
+ - **Example invalid**: ["string_step", 123, true]
+
+ 2. **Step identifier**: Must have either "class" OR "registry_name"
+ - **"registry_name"**: Registered step (preferred)
+ Example: {"registry_name": "normalize_step", "config": {...}}
+ - **"class"**: Full import path
+ Example: {"class": "lerobot.processor.normalize.NormalizeStep"}
+ - **If neither**: Invalid step (can't resolve class)
+ - **If both**: Also valid (registry_name takes precedence)
+
+ **Valid Processor Config Examples**:
+ - {"steps": []} - Empty processor
+ - {"steps": [{"registry_name": "normalize"}]} - Registry step
+ - {"steps": [{"class": "my.module.Step"}]} - Import step
+ - {"name": "MyProcessor", "steps": [...]} - With name
+
+ **Invalid Config Examples**:
+ - {"type": "act"} - Missing "steps"
+ - {"steps": "normalize"} - Steps not a list
+ - {"steps": [{}]} - Step missing class/registry_name
+ - {"steps": ["string"]} - Step not a dict
+
+ Args:
+ config: The configuration dictionary to validate
+
+ Returns:
+ True if config follows valid DataProcessorPipeline format, False otherwise
+ """
+ # Must have a "steps" field with a list of step configurations
+ if not isinstance(config.get("steps"), list):
+ return False
+
+ steps = config["steps"]
+ if len(steps) == 0:
+ return True # Empty processor is valid
+
+ # Each step must be a dict with either "class" or "registry_name"
+ for step in steps:
+ if not isinstance(step, dict):
+ return False
+ if not ("class" in step or "registry_name" in step):
+ return False
+
+ return True
+
+ @classmethod
+ def _suggest_processor_migration(cls, model_path: str | Path, original_error: str) -> None:
+ """Raise migration error when we detect JSON files but no processor configs.
+
+ This method is called when migration detection determines that a model
+ directory contains configuration files but none are valid processor configs.
+ This typically indicates an old LeRobot model that needs migration.
+
+ **When this is called**:
+ - User tries to load DataProcessorPipeline from local directory
+ - Directory contains JSON configuration files
+ - None of the JSON files follow processor config format
+ - _should_suggest_migration() returned True
+
+ **Migration Command Generation**:
+ - Constructs exact command user needs to run
+ - Uses the migration script: migrate_policy_normalization.py
+ - Includes the model path automatically
+ - Example: "python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path /models/old_model"
+
+ **Error Structure**:
+ - **Always raises**: ProcessorMigrationError (never returns)
+ - **Includes**: model_path, migration_command, original_error
+ - **Purpose**: Force user attention to migration need
+ - **User experience**: Clear actionable error with exact command to run
+
+ **Migration Process**:
+ The suggested command will:
+ 1. Extract normalization stats from old model
+ 2. Create new processor configs (preprocessor + postprocessor)
+ 3. Remove normalization layers from model
+ 4. Save migrated model with processor pipeline
+
+ Args:
+ model_path: Path to the model directory needing migration
+ original_error: The error that triggered migration detection (for context)
+
+ Raises:
+ ProcessorMigrationError: Always raised (this method never returns normally)
+ """
+ migration_command = (
+ f"python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path {model_path}"
+ )
+
+ raise ProcessorMigrationError(model_path, migration_command, original_error)
+
+ def __len__(self) -> int:
+ """Returns the number of steps in the pipeline."""
+ return len(self.steps)
+
+ def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]:
+ """Retrieves a step or a sub-pipeline by index or slice.
+
+ Args:
+ idx: An integer index or a slice object.
+
+ Returns:
+ A `ProcessorStep` if `idx` is an integer, or a new `DataProcessorPipeline`
+ containing the sliced steps.
+ """
+ if isinstance(idx, slice):
+ # Return a new pipeline instance with the sliced steps.
+ return DataProcessorPipeline(
+ steps=self.steps[idx],
+ name=self.name,
+ to_transition=self.to_transition,
+ to_output=self.to_output,
+ before_step_hooks=self.before_step_hooks.copy(),
+ after_step_hooks=self.after_step_hooks.copy(),
+ )
+ return self.steps[idx]
+
+ def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
+ """Registers a function to be called before each step.
+
+ Args:
+ fn: A callable that accepts the step index and the current transition.
+ """
+ self.before_step_hooks.append(fn)
+
+ def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
+ """Unregisters a 'before_step' hook.
+
+ Args:
+ fn: The exact function object that was previously registered.
+
+ Raises:
+ ValueError: If the hook is not found in the list.
+ """
+ try:
+ self.before_step_hooks.remove(fn)
+ except ValueError:
+ raise ValueError(
+ f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
+ ) from None
+
+ def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
+ """Registers a function to be called after each step.
+
+ Args:
+ fn: A callable that accepts the step index and the current transition.
+ """
+ self.after_step_hooks.append(fn)
+
+ def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
+ """Unregisters an 'after_step' hook.
+
+ Args:
+ fn: The exact function object that was previously registered.
+
+ Raises:
+ ValueError: If the hook is not found in the list.
+ """
+ try:
+ self.after_step_hooks.remove(fn)
+ except ValueError:
+ raise ValueError(
+ f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
+ ) from None
+
+ def reset(self):
+ """Resets the state of all stateful steps in the pipeline."""
+ for step in self.steps:
+ if hasattr(step, "reset"):
+ step.reset()
+
+ def __repr__(self) -> str:
+ """Provides a concise string representation of the pipeline."""
+ step_names = [step.__class__.__name__ for step in self.steps]
+
+ if not step_names:
+ steps_repr = "steps=0: []"
+ elif len(step_names) <= 3:
+ steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]"
+ else:
+ # For long pipelines, show the first, second, and last steps.
+ displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}"
+ steps_repr = f"steps={len(step_names)}: [{displayed}]"
+
+ parts = [f"name='{self.name}'", steps_repr]
+
+ return f"DataProcessorPipeline({', '.join(parts)})"
+
+ def __post_init__(self):
+ """Validates that all provided steps are instances of `ProcessorStep`."""
+ for i, step in enumerate(self.steps):
+ if not isinstance(step, ProcessorStep):
+ raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
+
+ def transform_features(
+ self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Applies feature transformations from all steps sequentially.
+
+ This method propagates a feature description dictionary through each step's
+ `transform_features` method, allowing the pipeline to statically determine
+ the output feature specification without processing any real data.
+
+ Args:
+ initial_features: A dictionary describing the initial features.
+
+ Returns:
+ The final feature description after all transformations.
+ """
+ features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features)
+
+ for _, step in enumerate(self.steps):
+ out = step.transform_features(features)
+ features = out
+ return features
+
+ # Convenience methods for processing individual parts of a transition.
+ def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+ """Processes only the observation part of a transition through the pipeline.
+
+ Args:
+ observation: The observation dictionary.
+
+ Returns:
+ The processed observation dictionary.
+ """
+ transition: EnvTransition = create_transition(observation=observation)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.OBSERVATION]
+
+ def process_action(
+ self, action: PolicyAction | RobotAction | EnvAction
+ ) -> PolicyAction | RobotAction | EnvAction:
+ """Processes only the action part of a transition through the pipeline.
+
+ Args:
+ action: The action data.
+
+ Returns:
+ The processed action.
+ """
+ transition: EnvTransition = create_transition(action=action)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.ACTION]
+
+ def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
+ """Processes only the reward part of a transition through the pipeline.
+
+ Args:
+ reward: The reward value.
+
+ Returns:
+ The processed reward.
+ """
+ transition: EnvTransition = create_transition(reward=reward)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.REWARD]
+
+ def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
+ """Processes only the done flag of a transition through the pipeline.
+
+ Args:
+ done: The done flag.
+
+ Returns:
+ The processed done flag.
+ """
+ transition: EnvTransition = create_transition(done=done)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.DONE]
+
+ def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
+ """Processes only the truncated flag of a transition through the pipeline.
+
+ Args:
+ truncated: The truncated flag.
+
+ Returns:
+ The processed truncated flag.
+ """
+ transition: EnvTransition = create_transition(truncated=truncated)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.TRUNCATED]
+
+ def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
+ """Processes only the info dictionary of a transition through the pipeline.
+
+ Args:
+ info: The info dictionary.
+
+ Returns:
+ The processed info dictionary.
+ """
+ transition: EnvTransition = create_transition(info=info)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.INFO]
+
+ def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
+ """Processes only the complementary data part of a transition through the pipeline.
+
+ Args:
+ complementary_data: The complementary data dictionary.
+
+ Returns:
+ The processed complementary data dictionary.
+ """
+ transition: EnvTransition = create_transition(complementary_data=complementary_data)
+ transformed_transition = self._forward(transition)
+ return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
+
+
+# Type aliases for semantic clarity.
+RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
+PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
+
+
+class ObservationProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the observation in a transition."""
+
+ @abstractmethod
+ def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+ """Processes an observation dictionary. Subclasses must implement this method.
+
+ Args:
+ observation: The input observation dictionary from the transition.
+
+ Returns:
+ The processed observation dictionary.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `observation` method to the transition's observation."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ observation = new_transition.get(TransitionKey.OBSERVATION)
+ if observation is None or not isinstance(observation, dict):
+ raise ValueError("ObservationProcessorStep requires an observation in the transition.")
+
+ processed_observation = self.observation(observation.copy())
+ new_transition[TransitionKey.OBSERVATION] = processed_observation
+ return new_transition
+
+
+class ActionProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the action in a transition."""
+
+ @abstractmethod
+ def action(
+ self, action: PolicyAction | RobotAction | EnvAction
+ ) -> PolicyAction | RobotAction | EnvAction:
+ """Processes an action. Subclasses must implement this method.
+
+ Args:
+ action: The input action from the transition.
+
+ Returns:
+ The processed action.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `action` method to the transition's action."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ action = new_transition.get(TransitionKey.ACTION)
+ if action is None:
+ raise ValueError("ActionProcessorStep requires an action in the transition.")
+
+ processed_action = self.action(action)
+ new_transition[TransitionKey.ACTION] = processed_action
+ return new_transition
+
+
+class RobotActionProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` for processing a `RobotAction` (a dictionary)."""
+
+ @abstractmethod
+ def action(self, action: RobotAction) -> RobotAction:
+ """Processes a `RobotAction`. Subclasses must implement this method.
+
+ Args:
+ action: The input `RobotAction` dictionary.
+
+ Returns:
+ The processed `RobotAction`.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `action` method to the transition's action, ensuring it's a `RobotAction`."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ action = new_transition.get(TransitionKey.ACTION)
+ if action is None or not isinstance(action, dict):
+ raise ValueError(f"Action should be a RobotAction type (dict), but got {type(action)}")
+
+ processed_action = self.action(action.copy())
+ new_transition[TransitionKey.ACTION] = processed_action
+ return new_transition
+
+
+class PolicyActionProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` for processing a `PolicyAction` (a tensor or dict of tensors)."""
+
+ @abstractmethod
+ def action(self, action: PolicyAction) -> PolicyAction:
+ """Processes a `PolicyAction`. Subclasses must implement this method.
+
+ Args:
+ action: The input `PolicyAction`.
+
+ Returns:
+ The processed `PolicyAction`.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `action` method to the transition's action, ensuring it's a `PolicyAction`."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ action = new_transition.get(TransitionKey.ACTION)
+ if not isinstance(action, PolicyAction):
+ raise ValueError(f"Action should be a PolicyAction type (tensor), but got {type(action)}")
+
+ processed_action = self.action(action)
+ new_transition[TransitionKey.ACTION] = processed_action
+ return new_transition
+
+
+class RewardProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the reward in a transition."""
+
+ @abstractmethod
+ def reward(self, reward) -> float | torch.Tensor:
+ """Processes a reward. Subclasses must implement this method.
+
+ Args:
+ reward: The input reward from the transition.
+
+ Returns:
+ The processed reward.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `reward` method to the transition's reward."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ reward = new_transition.get(TransitionKey.REWARD)
+ if reward is None:
+ raise ValueError("RewardProcessorStep requires a reward in the transition.")
+
+ processed_reward = self.reward(reward)
+ new_transition[TransitionKey.REWARD] = processed_reward
+ return new_transition
+
+
+class DoneProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the 'done' flag in a transition."""
+
+ @abstractmethod
+ def done(self, done) -> bool | torch.Tensor:
+ """Processes a 'done' flag. Subclasses must implement this method.
+
+ Args:
+ done: The input 'done' flag from the transition.
+
+ Returns:
+ The processed 'done' flag.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `done` method to the transition's 'done' flag."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ done = new_transition.get(TransitionKey.DONE)
+ if done is None:
+ raise ValueError("DoneProcessorStep requires a done flag in the transition.")
+
+ processed_done = self.done(done)
+ new_transition[TransitionKey.DONE] = processed_done
+ return new_transition
+
+
+class TruncatedProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the 'truncated' flag in a transition."""
+
+ @abstractmethod
+ def truncated(self, truncated) -> bool | torch.Tensor:
+ """Processes a 'truncated' flag. Subclasses must implement this method.
+
+ Args:
+ truncated: The input 'truncated' flag from the transition.
+
+ Returns:
+ The processed 'truncated' flag.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `truncated` method to the transition's 'truncated' flag."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ truncated = new_transition.get(TransitionKey.TRUNCATED)
+ if truncated is None:
+ raise ValueError("TruncatedProcessorStep requires a truncated flag in the transition.")
+
+ processed_truncated = self.truncated(truncated)
+ new_transition[TransitionKey.TRUNCATED] = processed_truncated
+ return new_transition
+
+
+class InfoProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that specifically targets the 'info' dictionary in a transition."""
+
+ @abstractmethod
+ def info(self, info) -> dict[str, Any]:
+ """Processes an 'info' dictionary. Subclasses must implement this method.
+
+ Args:
+ info: The input 'info' dictionary from the transition.
+
+ Returns:
+ The processed 'info' dictionary.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `info` method to the transition's 'info' dictionary."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ info = new_transition.get(TransitionKey.INFO)
+ if info is None or not isinstance(info, dict):
+ raise ValueError("InfoProcessorStep requires an info dictionary in the transition.")
+
+ processed_info = self.info(info.copy())
+ new_transition[TransitionKey.INFO] = processed_info
+ return new_transition
+
+
+class ComplementaryDataProcessorStep(ProcessorStep, ABC):
+ """An abstract `ProcessorStep` that targets the 'complementary_data' in a transition."""
+
+ @abstractmethod
+ def complementary_data(self, complementary_data) -> dict[str, Any]:
+ """Processes a 'complementary_data' dictionary. Subclasses must implement this method.
+
+ Args:
+ complementary_data: The input 'complementary_data' from the transition.
+
+ Returns:
+ The processed 'complementary_data' dictionary.
+ """
+ ...
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Applies the `complementary_data` method to the transition's data."""
+ self._current_transition = transition.copy()
+ new_transition = self._current_transition
+
+ complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA)
+ if complementary_data is None or not isinstance(complementary_data, dict):
+ raise ValueError("ComplementaryDataProcessorStep requires complementary data in the transition.")
+
+ processed_complementary_data = self.complementary_data(complementary_data.copy())
+ new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
+ return new_transition
+
+
+class IdentityProcessorStep(ProcessorStep):
+ """A no-op processor step that returns the input transition and features unchanged.
+
+ This can be useful as a placeholder or for debugging purposes.
+ """
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Returns the transition without modification."""
+ return transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Returns the features without modification."""
+ return features
diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py
new file mode 100644
index 000000000..25887d414
--- /dev/null
+++ b/src/lerobot/processor/policy_robot_bridge.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import asdict, dataclass
+from typing import Any
+
+import torch
+
+from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
+from lerobot.utils.constants import ACTION
+
+
+@dataclass
+@ProcessorStepRegistry.register("robot_action_to_policy_action_processor")
+class RobotActionToPolicyActionProcessorStep(ActionProcessorStep):
+ """Processor step to map a dictionary to a tensor action."""
+
+ motor_names: list[str]
+
+ def action(self, action: RobotAction) -> PolicyAction:
+ if len(self.motor_names) != len(action):
+ raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}")
+ return torch.tensor([action[f"{name}.pos"] for name in self.motor_names])
+
+ def get_config(self) -> dict[str, Any]:
+ return asdict(self)
+
+ def transform_features(self, features):
+ features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(len(self.motor_names),)
+ )
+ return features
+
+
+@dataclass
+@ProcessorStepRegistry.register("policy_action_to_robot_action_processor")
+class PolicyActionToRobotActionProcessorStep(ActionProcessorStep):
+ """Processor step to map a policy action to a robot action."""
+
+ motor_names: list[str]
+
+ def action(self, action: PolicyAction) -> RobotAction:
+ if len(self.motor_names) != len(action):
+ raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}")
+ return {f"{name}.pos": action[i] for i, name in enumerate(self.motor_names)}
+
+ def get_config(self) -> dict[str, Any]:
+ return asdict(self)
+
+ def transform_features(self, features):
+ for name in self.motor_names:
+ features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+ return features
diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py
new file mode 100644
index 000000000..6cae5921f
--- /dev/null
+++ b/src/lerobot/processor/rename_processor.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+
+from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="rename_observations_processor")
+class RenameObservationsProcessorStep(ObservationProcessorStep):
+ """
+ A processor step that renames keys in an observation dictionary.
+
+ This step is useful for creating a standardized data interface by mapping keys
+ from an environment's format to the format expected by a LeRobot policy or
+ other downstream components.
+
+ Attributes:
+ rename_map: A dictionary mapping from old key names to new key names.
+ Keys present in an observation that are not in this map will
+ be kept with their original names.
+ """
+
+ rename_map: dict[str, str] = field(default_factory=dict)
+
+ def observation(self, observation):
+ processed_obs = {}
+ for key, value in observation.items():
+ if key in self.rename_map:
+ processed_obs[self.rename_map[key]] = value
+ else:
+ processed_obs[key] = value
+
+ return processed_obs
+
+ def get_config(self) -> dict[str, Any]:
+ return {"rename_map": self.rename_map}
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Transforms:
+ - Each key in the observation that appears in `rename_map` is renamed to its value.
+ - Keys not in `rename_map` remain unchanged.
+ """
+ new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = features.copy()
+ new_features[PipelineFeatureType.OBSERVATION] = {
+ self.rename_map.get(k, k): v for k, v in features[PipelineFeatureType.OBSERVATION].items()
+ }
+ return new_features
+
+
+def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
+ """
+ Renames the top-level keys in a statistics dictionary using a provided mapping.
+
+ This is a helper function typically used to keep normalization statistics
+ consistent with renamed observation or action features. It performs a defensive
+ deep copy to avoid modifying the original `stats` dictionary.
+
+ Args:
+ stats: A nested dictionary of statistics, where top-level keys are
+ feature names (e.g., `{"observation.state": {"mean": 0.5}}`).
+ rename_map: A dictionary mapping old feature names to new feature names.
+
+ Returns:
+ A new statistics dictionary with its top-level keys renamed. Returns an
+ empty dictionary if the input `stats` is empty.
+ """
+ if not stats:
+ return {}
+ renamed: dict[str, dict[str, Any]] = {}
+ for old_key, sub_stats in stats.items():
+ new_key = rename_map.get(old_key, old_key)
+ renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
+ return renamed
diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py
new file mode 100644
index 000000000..2ef89c107
--- /dev/null
+++ b/src/lerobot/processor/tokenizer_processor.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script defines a processor for tokenizing natural language instructions from an environment transition.
+
+It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
+token IDs and attention masks, which are then added to the observation dictionary.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+import torch
+
+from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
+from lerobot.utils.import_utils import _transformers_available
+
+from .core import EnvTransition, TransitionKey
+from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
+
+# Conditional import for type checking and lazy loading
+if TYPE_CHECKING or _transformers_available:
+ from transformers import AutoTokenizer
+else:
+ AutoTokenizer = None
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="tokenizer_processor")
+class TokenizerProcessorStep(ObservationProcessorStep):
+ """
+ Processor step to tokenize a natural language task description.
+
+ This step extracts a task string from the `complementary_data` of an `EnvTransition`,
+ tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
+ token IDs and attention mask to the `observation` dictionary.
+
+ Requires the `transformers` library to be installed.
+
+ Attributes:
+ tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased").
+ tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
+ max_length: The maximum length to pad or truncate sequences to.
+ task_key: The key in `complementary_data` where the task string is stored.
+ padding_side: The side to pad on ('left' or 'right').
+ padding: The padding strategy ('max_length', 'longest', etc.).
+ truncation: Whether to truncate sequences longer than `max_length`.
+ input_tokenizer: The internal tokenizer instance, loaded during initialization.
+ """
+
+ tokenizer_name: str | None = None
+ tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
+ max_length: int = 512
+ task_key: str = "task"
+ padding_side: str = "right"
+ padding: str = "max_length"
+ truncation: bool = True
+
+ # Internal tokenizer instance (not part of the config)
+ input_tokenizer: Any = field(default=None, init=False, repr=False)
+
+ def __post_init__(self):
+ """
+ Initializes the tokenizer after the dataclass is created.
+
+ It checks for the availability of the `transformers` library and loads the tokenizer
+ either from a provided object or by name from the Hugging Face Hub.
+
+ Raises:
+ ImportError: If the `transformers` library is not installed.
+ ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
+ """
+ if not _transformers_available:
+ raise ImportError(
+ "The 'transformers' library is not installed. "
+ "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
+ )
+
+ if self.tokenizer is not None:
+ # Use provided tokenizer object directly
+ self.input_tokenizer = self.tokenizer
+ elif self.tokenizer_name is not None:
+ if AutoTokenizer is None:
+ raise ImportError("AutoTokenizer is not available")
+ self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+ else:
+ raise ValueError(
+ "Either 'tokenizer' or 'tokenizer_name' must be provided. "
+ "Pass a tokenizer object directly or a tokenizer name to auto-load."
+ )
+
+ def get_task(self, transition: EnvTransition) -> list[str] | None:
+ """
+ Extracts the task description(s) from the transition's complementary data.
+
+ Args:
+ transition: The environment transition.
+
+ Returns:
+ A list of task strings, or None if the task key is not found or the value is None.
+ """
+ complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
+ if complementary_data is None:
+ raise ValueError("Complementary data is None so no task can be extracted from it")
+
+ task = complementary_data[self.task_key]
+ if task is None:
+ raise ValueError("Task extracted from Complementary data is None")
+
+ # Standardize to a list of strings for the tokenizer
+ if isinstance(task, str):
+ return [task]
+ elif isinstance(task, list) and all(isinstance(t, str) for t in task):
+ return task
+
+ return None
+
+ def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+ """
+ Tokenizes the task description and adds it to the observation dictionary.
+
+ This method retrieves the task, tokenizes it, moves the resulting tensors to the
+ same device as other data in the transition, and updates the observation.
+
+ Args:
+ observation: The original observation dictionary.
+
+ Returns:
+ The updated observation dictionary including token IDs and an attention mask.
+ """
+ task = self.get_task(self.transition)
+ if task is None:
+ raise ValueError("Task cannot be None")
+
+ # Tokenize the task (this will create CPU tensors)
+ tokenized_prompt = self._tokenize_text(task)
+
+ # Detect the device from existing tensors in the transition to ensure consistency
+ target_device = self._detect_device(self.transition)
+
+ # Move new tokenized tensors to the detected device
+ if target_device is not None:
+ tokenized_prompt = {
+ k: v.to(target_device) if isinstance(v, torch.Tensor) else v
+ for k, v in tokenized_prompt.items()
+ }
+
+ # Create a new observation dict to avoid modifying the original in place
+ new_observation = dict(observation)
+
+ # Add tokenized data to the observation
+ new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
+ new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
+
+ return new_observation
+
+ def _detect_device(self, transition: EnvTransition) -> torch.device | None:
+ """
+ Detects the torch.device from existing tensors in the transition.
+
+ It checks tensors in the observation dictionary first, then the action tensor.
+
+ Args:
+ transition: The environment transition.
+
+ Returns:
+ The detected `torch.device`, or None if no tensors are found.
+ """
+ # Check observation tensors first (most likely place to find tensors)
+ observation = transition.get(TransitionKey.OBSERVATION)
+ if observation:
+ for value in observation.values():
+ if isinstance(value, torch.Tensor):
+ return value.device
+
+ # Fallback to checking the action tensor
+ action = transition.get(TransitionKey.ACTION)
+ if isinstance(action, torch.Tensor):
+ return action.device
+
+ return None # No tensors found, default will be CPU
+
+ def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
+ """
+ A wrapper around the tokenizer call.
+
+ Args:
+ text: A string or list of strings to tokenize.
+
+ Returns:
+ A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
+ """
+ return self.input_tokenizer(
+ text,
+ max_length=self.max_length,
+ truncation=self.truncation,
+ padding=self.padding,
+ padding_side=self.padding_side,
+ return_tensors="pt",
+ )
+
+ def get_config(self) -> dict[str, Any]:
+ """
+ Returns the serializable configuration of the processor.
+
+ Note: The tokenizer object itself is not serialized. If the processor was initialized
+ with a tokenizer name, that name will be included in the config.
+
+ Returns:
+ A dictionary with the processor's configuration parameters.
+ """
+ config = {
+ "max_length": self.max_length,
+ "task_key": self.task_key,
+ "padding_side": self.padding_side,
+ "padding": self.padding,
+ "truncation": self.truncation,
+ }
+
+ # Only save tokenizer_name if it was used to create the tokenizer
+ if self.tokenizer_name is not None and self.tokenizer is None:
+ config["tokenizer_name"] = self.tokenizer_name
+
+ return config
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Adds feature definitions for the language tokens and attention mask.
+
+ This updates the policy features dictionary to include the new data added to the
+ observation, ensuring downstream components are aware of their shape and type.
+
+ Args:
+ features: The dictionary of existing policy features.
+
+ Returns:
+ The updated dictionary of policy features.
+ """
+ # Add a feature for the token IDs if it doesn't already exist
+ if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
+ features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
+ type=FeatureType.LANGUAGE, shape=(self.max_length,)
+ )
+
+ # Add a feature for the attention mask if it doesn't already exist
+ if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
+ features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
+ type=FeatureType.LANGUAGE, shape=(self.max_length,)
+ )
+
+ return features
diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/rl/actor.py
similarity index 80%
rename from src/lerobot/scripts/rl/actor.py
rename to src/lerobot/rl/actor.py
index 0e96d3354..54d0fba69 100644
--- a/src/lerobot/scripts/rl/actor.py
+++ b/src/lerobot/rl/actor.py
@@ -24,7 +24,7 @@ Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
-python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
+python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
@@ -35,7 +35,7 @@ gamepad to take control of the robot during training. Initially intervene freque
reduce interventions as the policy improves.
**WORKFLOW**:
-1. Determine robot workspace bounds using `find_joint_limits.py`
+1. Determine robot workspace bounds using `lerobot-find-joint-limits`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
@@ -62,20 +62,21 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
+from lerobot.processor import TransitionKey
+from lerobot.rl.process import ProcessSignalHandler
+from lerobot.rl.queue import get_last_item_from_queue
from lerobot.robots import so100_follower # noqa: F401
-from lerobot.scripts.rl import learner_service
-from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
+from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
+ grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
-from lerobot.utils.process import ProcessSignalHandler
-from lerobot.utils.queue import get_last_item_from_queue
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.transition import (
@@ -89,12 +90,14 @@ from lerobot.utils.utils import (
init_logging,
)
-ACTOR_SHUTDOWN_TIMEOUT = 30
+from .gym_manipulator import (
+ create_transition,
+ make_processors,
+ make_robot_env,
+ step_env_and_process_transition,
+)
-
-#################################################
-# Main entry point #
-#################################################
+# Main entry point
@parser.wrap()
@@ -201,9 +204,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
logging.info("[ACTOR] queues closed")
-#################################################
-# Core algorithm functions #
-#################################################
+# Core algorithm functions
def act_with_policy(
@@ -236,7 +237,8 @@ def act_with_policy(
logging.info("make_env online")
- online_env = make_robot_env(cfg=cfg.env)
+ online_env, teleop_device = make_robot_env(cfg=cfg.env)
+ env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -257,6 +259,12 @@ def act_with_policy(
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
+ env_processor.reset()
+ action_processor.reset()
+
+ # Process initial observation
+ transition = create_transition(observation=obs, info=info)
+ transition = env_processor(transition)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
@@ -274,50 +282,76 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy")
return
- if interaction_step >= cfg.policy.online_step_before_learning:
- # Time policy inference and check if it meets FPS requirement
- with policy_timer:
- action = policy.select_action(batch=obs)
- policy_fps = policy_timer.fps_last
+ observation = {
+ k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
+ }
- log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+ # Time policy inference and check if it meets FPS requirement
+ with policy_timer:
+ # Extract observation from transition for policy
+ action = policy.select_action(batch=observation)
+ policy_fps = policy_timer.fps_last
- else:
- action = online_env.action_space.sample()
+ log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
- next_obs, reward, done, truncated, info = online_env.step(action)
+ # Use the new step function
+ new_transition = step_env_and_process_transition(
+ env=online_env,
+ transition=transition,
+ action=action,
+ env_processor=env_processor,
+ action_processor=action_processor,
+ )
+
+ # Extract values from processed transition
+ next_observation = {
+ k: v
+ for k, v in new_transition[TransitionKey.OBSERVATION].items()
+ if k in cfg.policy.input_features
+ }
+
+ # Teleop action is the action that was executed in the environment
+ # It is either the action from the teleop device or the action from the policy
+ executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
+
+ reward = new_transition[TransitionKey.REWARD]
+ done = new_transition.get(TransitionKey.DONE, False)
+ truncated = new_transition.get(TransitionKey.TRUNCATED, False)
sum_reward_episode += float(reward)
- # Increment total steps counter for intervention rate
episode_total_steps += 1
- # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
- if "is_intervention" in info and info["is_intervention"]:
- # NOTE: The action space for demonstration before hand is with the full action space
- # but sometimes for example we want to deactivate the gripper
- action = info["action_intervention"]
+ # Check for intervention from transition info
+ intervention_info = new_transition[TransitionKey.INFO]
+ if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
episode_intervention = True
- # Increment intervention steps counter
episode_intervention_steps += 1
+ complementary_info = {
+ "discrete_penalty": torch.tensor(
+ [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
+ ),
+ }
+ # Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
- state=obs,
- action=action,
+ state=observation,
+ action=executed_action,
reward=reward,
- next_state=next_obs,
+ next_state=next_observation,
done=done,
- truncated=truncated, # TODO: (azouitine) Handle truncation properly
- complementary_info=info,
+ truncated=truncated,
+ complementary_info=complementary_info,
)
)
- # assign obs to the next obs and continue the rollout
- obs = next_obs
+
+ # Update transition for next iteration
+ transition = new_transition
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
- update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
+ update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -347,21 +381,27 @@ def act_with_policy(
)
)
- # Reset intervention counters
+ # Reset intervention counters and environment
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
+
+ # Reset environment and processors
obs, info = online_env.reset()
+ env_processor.reset()
+ action_processor.reset()
+
+ # Process initial observation
+ transition = create_transition(observation=obs, info=info)
+ transition = env_processor(transition)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.env.fps - dt_time)
-#################################################
-# Communication Functions - Group all gRPC/messaging functions #
-#################################################
+# Communication Functions - Group all gRPC/messaging functions
def establish_learner_connection(
@@ -399,8 +439,6 @@ def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
- import json
-
"""
Returns a client for the learner service.
@@ -408,34 +446,9 @@ def learner_service_client(
So we need to create only one client and reuse it.
"""
- service_config = {
- "methodConfig": [
- {
- "name": [{}], # Applies to ALL methods in ALL services
- "retryPolicy": {
- "maxAttempts": 5, # Max retries (total attempts = 5)
- "initialBackoff": "0.1s", # First retry after 0.1s
- "maxBackoff": "2s", # Max wait time between retries
- "backoffMultiplier": 2, # Exponential backoff factor
- "retryableStatusCodes": [
- "UNAVAILABLE",
- "DEADLINE_EXCEEDED",
- ], # Retries on network failures
- },
- }
- ]
- }
-
- service_config_json = json.dumps(service_config)
-
channel = grpc.insecure_channel(
f"{host}:{port}",
- options=[
- ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
- ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
- ("grpc.enable_retries", 1),
- ("grpc.service_config", service_config_json),
- ],
+ grpc_channel_options(),
)
stub = services_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
@@ -633,23 +646,39 @@ def interactions_stream(
return services_pb2.Empty()
-#################################################
-# Policy functions #
-#################################################
+# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
- state_dict = bytes_to_state_dict(bytes_state_dict)
- state_dict = move_state_dict_to_device(state_dict, device=device)
- policy.load_state_dict(state_dict)
+ state_dicts = bytes_to_state_dict(bytes_state_dict)
+
+ # TODO: check encoder parameter synchronization possible issues:
+ # 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
+ # instead of the updated encoder params from critic (which is optimized separately)
+ # 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
+ # 3. Need to handle encoder params correctly for both actor and discrete_critic
+ # Potential fixes:
+ # - Send critic's encoder state when shared_encoder=True
+ # - Skip encoder params entirely when freeze_vision_encoder=True
+ # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
+
+ # Load actor state dict
+ actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
+ policy.actor.load_state_dict(actor_state_dict)
+
+ # Load discrete critic if present
+ if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
+ discrete_critic_state_dict = move_state_dict_to_device(
+ state_dicts["discrete_critic"], device=device
+ )
+ policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
+ logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
-#################################################
-# Utilities functions #
-#################################################
+# Utilities functions
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/rl/buffer.py
similarity index 94%
rename from src/lerobot/utils/buffer.py
rename to src/lerobot/rl/buffer.py
index 7f8d989dd..917e4e2cc 100644
--- a/src/lerobot/utils/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -15,14 +15,16 @@
# limitations under the License.
import functools
+from collections.abc import Callable, Sequence
from contextlib import suppress
-from typing import Callable, Sequence, TypedDict
+from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
from lerobot.utils.transition import Transition
@@ -174,7 +176,7 @@ class ReplayBuffer:
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
- elif isinstance(value, (int, float)):
+ elif isinstance(value, (int | float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
@@ -221,7 +223,7 @@ class ReplayBuffer:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
- elif isinstance(value, (int, float)):
+ elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
@@ -239,7 +241,7 @@ class ReplayBuffer:
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
- image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
+ image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
@@ -465,7 +467,7 @@ class ReplayBuffer:
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
- first_action = first_transition["action"].to(device)
+ first_action = first_transition[ACTION].to(device)
# Get complementary info if available
first_complementary_info = None
@@ -490,7 +492,7 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
- action = data["action"]
+ action = data[ACTION]
replay_buffer.add(
state=data["state"],
@@ -528,12 +530,12 @@ class ReplayBuffer:
# Add "action"
sample_action = self.actions[0]
- act_info = guess_feature_info(t=sample_action, name="action")
- features["action"] = act_info
+ act_info = guess_feature_info(t=sample_action, name=ACTION)
+ features[ACTION] = act_info
# Add "reward" and "done"
- features["next.reward"] = {"dtype": "float32", "shape": (1,)}
- features["next.done"] = {"dtype": "bool", "shape": (1,)}
+ features[REWARD] = {"dtype": "float32", "shape": (1,)}
+ features[DONE] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.states:
@@ -564,10 +566,7 @@ class ReplayBuffer:
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
# Convert transitions into episodes and frames
- episode_index = 0
- lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
- frame_idx_in_episode = 0
for idx in range(self.size):
actual_idx = (self.position - self.size + idx) % self.capacity
@@ -578,9 +577,10 @@ class ReplayBuffer:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
- frame_dict["action"] = self.actions[actual_idx].cpu()
- frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
- frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
+ frame_dict[ACTION] = self.actions[actual_idx].cpu()
+ frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
+ frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
+ frame_dict["task"] = task_name
# Add complementary_info if available
if self.has_complementary_info:
@@ -596,19 +596,11 @@ class ReplayBuffer:
frame_dict[f"complementary_info.{key}"] = val
# Add to the dataset's buffer
- lerobot_dataset.add_frame(frame_dict, task=task_name)
-
- # Move to next frame
- frame_idx_in_episode += 1
+ lerobot_dataset.add_frame(frame_dict)
# If we reached an episode boundary, call save_episode, reset counters
if self.dones[actual_idx] or self.truncateds[actual_idx]:
lerobot_dataset.save_episode()
- episode_index += 1
- frame_idx_in_episode = 0
- lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
- episode_index=episode_index
- )
# Save any remaining frames in the buffer
if lerobot_dataset.episode_buffer["size"] > 0:
@@ -656,7 +648,7 @@ class ReplayBuffer:
# Check if the dataset has "next.done" key
sample = dataset[0]
- has_done_key = "next.done" in sample
+ has_done_key = DONE in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
@@ -676,14 +668,14 @@ class ReplayBuffer:
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
- action = current_sample["action"].unsqueeze(0) # Add batch dimension
+ action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
- reward = float(current_sample["next.reward"].item()) # ensure float
+ reward = float(current_sample[REWARD].item()) # ensure float
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
if has_done_key:
- done = bool(current_sample["next.done"].item()) # ensure bool
+ done = bool(current_sample[DONE].item()) # ensure bool
else:
# If this is the last frame or if next frame is in a different episode, mark as done
done = False
@@ -796,8 +788,8 @@ def concatenate_batch_transitions(
}
# Concatenate basic fields
- left_batch_transitions["action"] = torch.cat(
- [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
+ left_batch_transitions[ACTION] = torch.cat(
+ [left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
diff --git a/src/lerobot/scripts/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py
similarity index 91%
rename from src/lerobot/scripts/rl/crop_dataset_roi.py
rename to src/lerobot/rl/crop_dataset_roi.py
index 0b71b5363..4345fed3c 100644
--- a/src/lerobot/scripts/rl/crop_dataset_roi.py
+++ b/src/lerobot/rl/crop_dataset_roi.py
@@ -18,7 +18,6 @@ import argparse
import json
from copy import deepcopy
from pathlib import Path
-from typing import Dict, Tuple
import cv2
import torch
@@ -26,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import DONE, REWARD
def select_rect_roi(img):
@@ -160,12 +160,12 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
return image_dict
-def convert_lerobot_dataset_to_cropper_lerobot_dataset(
+def convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset: LeRobotDataset,
- crop_params_dict: Dict[str, Tuple[int, int, int, int]],
+ crop_params_dict: dict[str, tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
- resize_size: Tuple[int, int] = (128, 128),
+ resize_size: tuple[int, int] = (128, 128),
push_to_hub: bool = False,
task: str = "",
) -> LeRobotDataset:
@@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
- fps=original_dataset.fps,
+ fps=int(original_dataset.fps),
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
@@ -213,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
continue
- if key in ("next.done", "next.reward"):
+ if key in (DONE, REWARD):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
value = value.unsqueeze(0)
new_frame[key] = value
- new_dataset.add_frame(new_frame, task=task)
+ new_frame["task"] = task
+ new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
@@ -274,6 +275,12 @@ if __name__ == "__main__":
default="",
help="The natural language task to describe the dataset.",
)
+ parser.add_argument(
+ "--new-repo-id",
+ type=str,
+ default=None,
+ help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.",
+ )
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
@@ -293,10 +300,16 @@ if __name__ == "__main__":
for key, roi in rois.items():
print(f"{key}: {roi}")
- new_repo_id = args.repo_id + "_cropped_resized"
- new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
+ new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized"
- cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
+ if args.new_repo_id:
+ new_dataset_name = args.new_repo_id.split("/")[-1]
+ # Parent 1: HF user, Parent 2: HF LeRobot Home
+ new_dataset_root = dataset.root.parent.parent / new_dataset_name
+ else:
+ new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
+
+ cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
diff --git a/src/lerobot/scripts/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py
similarity index 97%
rename from src/lerobot/scripts/rl/eval_policy.py
rename to src/lerobot/rl/eval_policy.py
index aa97483b6..7cec66800 100644
--- a/src/lerobot/scripts/rl/eval_policy.py
+++ b/src/lerobot/rl/eval_policy.py
@@ -25,12 +25,13 @@ from lerobot.robots import ( # noqa: F401
make_robot_from_config,
so100_follower,
)
-from lerobot.scripts.rl.gym_manipulator import make_robot_env
from lerobot.teleoperators import (
gamepad, # noqa: F401
so101_leader, # noqa: F401
)
+from .gym_manipulator import make_robot_env
+
logging.basicConfig(level=logging.INFO)
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
new file mode 100644
index 000000000..ad36f1b36
--- /dev/null
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -0,0 +1,770 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any
+
+import gymnasium as gym
+import numpy as np
+import torch
+
+from lerobot.cameras import opencv # noqa: F401
+from lerobot.configs import parser
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.envs.configs import HILSerlRobotEnvConfig
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ AddTeleopActionAsComplimentaryDataStep,
+ AddTeleopEventsAsInfoStep,
+ DataProcessorPipeline,
+ DeviceProcessorStep,
+ EnvTransition,
+ GripperPenaltyProcessorStep,
+ ImageCropResizeProcessorStep,
+ InterventionActionProcessorStep,
+ JointVelocityProcessorStep,
+ MapDeltaActionToRobotActionStep,
+ MapTensorToDeltaActionDictStep,
+ MotorCurrentProcessorStep,
+ Numpy2TorchActionProcessorStep,
+ RewardClassifierProcessorStep,
+ RobotActionToPolicyActionProcessorStep,
+ TimeLimitProcessorStep,
+ Torch2NumpyActionProcessorStep,
+ TransitionKey,
+ VanillaObservationProcessorStep,
+ create_transition,
+)
+from lerobot.processor.converters import identity_transition
+from lerobot.robots import ( # noqa: F401
+ RobotConfig,
+ make_robot_from_config,
+ so100_follower,
+)
+from lerobot.robots.robot import Robot
+from lerobot.robots.so100_follower.robot_kinematic_processor import (
+ EEBoundsAndSafety,
+ EEReferenceAndDelta,
+ ForwardKinematicsJointsToEEObservation,
+ GripperVelocityToJoint,
+ InverseKinematicsRLStep,
+)
+from lerobot.teleoperators import (
+ gamepad, # noqa: F401
+ keyboard, # noqa: F401
+ make_teleoperator_from_config,
+ so101_leader, # noqa: F401
+)
+from lerobot.teleoperators.teleoperator import Teleoperator
+from lerobot.teleoperators.utils import TeleopEvents
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.utils import log_say
+
+logging.basicConfig(level=logging.INFO)
+
+
+@dataclass
+class DatasetConfig:
+ """Configuration for dataset creation and management."""
+
+ repo_id: str
+ task: str
+ root: str | None = None
+ num_episodes_to_record: int = 5
+ replay_episode: int | None = None
+ push_to_hub: bool = False
+
+
+@dataclass
+class GymManipulatorConfig:
+ """Main configuration for gym manipulator environment."""
+
+ env: HILSerlRobotEnvConfig
+ dataset: DatasetConfig
+ mode: str | None = None # Either "record", "replay", None
+ device: str = "cpu"
+
+
+def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
+ """Reset robot arm to target position using smooth trajectory."""
+ current_position_dict = robot_arm.bus.sync_read("Present_Position")
+ current_position = np.array(
+ [current_position_dict[name] for name in current_position_dict], dtype=np.float32
+ )
+ trajectory = torch.from_numpy(
+ np.linspace(current_position, target_position, 50)
+ ) # NOTE: 30 is just an arbitrary number
+ for pose in trajectory:
+ action_dict = dict(zip(current_position_dict, pose, strict=False))
+ robot_arm.bus.sync_write("Goal_Position", action_dict)
+ busy_wait(0.015)
+
+
+class RobotEnv(gym.Env):
+ """Gym environment for robotic control with human intervention support."""
+
+ def __init__(
+ self,
+ robot,
+ use_gripper: bool = False,
+ display_cameras: bool = False,
+ reset_pose: list[float] | None = None,
+ reset_time_s: float = 5.0,
+ ) -> None:
+ """Initialize robot environment with configuration options.
+
+ Args:
+ robot: Robot interface for hardware communication.
+ use_gripper: Whether to include gripper in action space.
+ display_cameras: Whether to show camera feeds during execution.
+ reset_pose: Joint positions for environment reset.
+ reset_time_s: Time to wait during reset.
+ """
+ super().__init__()
+
+ self.robot = robot
+ self.display_cameras = display_cameras
+
+ # Connect to the robot if not already connected.
+ if not self.robot.is_connected:
+ self.robot.connect()
+
+ # Episode tracking.
+ self.current_step = 0
+ self.episode_data = None
+
+ self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
+ self._image_keys = self.robot.cameras.keys()
+
+ self.reset_pose = reset_pose
+ self.reset_time_s = reset_time_s
+
+ self.use_gripper = use_gripper
+
+ self._joint_names = list(self.robot.bus.motors.keys())
+ self._raw_joint_positions = None
+
+ self._setup_spaces()
+
+ def _get_observation(self) -> dict[str, Any]:
+ """Get current robot observation including joint positions and camera images."""
+ obs_dict = self.robot.get_observation()
+ raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names}
+ joint_positions = np.array([raw_joint_joint_position[f"{name}.pos"] for name in self._joint_names])
+
+ images = {key: obs_dict[key] for key in self._image_keys}
+
+ return {"agent_pos": joint_positions, "pixels": images, **raw_joint_joint_position}
+
+ def _setup_spaces(self) -> None:
+ """Configure observation and action spaces based on robot capabilities."""
+ current_observation = self._get_observation()
+
+ observation_spaces = {}
+
+ # Define observation spaces for images and other states.
+ if current_observation is not None and "pixels" in current_observation:
+ prefix = OBS_IMAGES
+ observation_spaces = {
+ f"{prefix}.{key}": gym.spaces.Box(
+ low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
+ )
+ for key in current_observation["pixels"]
+ }
+
+ if current_observation is not None:
+ agent_pos = current_observation["agent_pos"]
+ observation_spaces[OBS_STATE] = gym.spaces.Box(
+ low=0,
+ high=10,
+ shape=agent_pos.shape,
+ dtype=np.float32,
+ )
+
+ self.observation_space = gym.spaces.Dict(observation_spaces)
+
+ # Define the action space for joint positions along with setting an intervention flag.
+ action_dim = 3
+ bounds = {}
+ bounds["min"] = -np.ones(action_dim)
+ bounds["max"] = np.ones(action_dim)
+
+ if self.use_gripper:
+ action_dim += 1
+ bounds["min"] = np.concatenate([bounds["min"], [0]])
+ bounds["max"] = np.concatenate([bounds["max"], [2]])
+
+ self.action_space = gym.spaces.Box(
+ low=bounds["min"],
+ high=bounds["max"],
+ shape=(action_dim,),
+ dtype=np.float32,
+ )
+
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """Reset environment to initial state.
+
+ Args:
+ seed: Random seed for reproducibility.
+ options: Additional reset options.
+
+ Returns:
+ Tuple of (observation, info) dictionaries.
+ """
+ # Reset the robot
+ # self.robot.reset()
+ start_time = time.perf_counter()
+ if self.reset_pose is not None:
+ log_say("Reset the environment.", play_sounds=True)
+ reset_follower_position(self.robot, np.array(self.reset_pose))
+ log_say("Reset the environment done.", play_sounds=True)
+
+ busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
+
+ super().reset(seed=seed, options=options)
+
+ # Reset episode tracking variables.
+ self.current_step = 0
+ self.episode_data = None
+ obs = self._get_observation()
+ self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
+ return obs, {TeleopEvents.IS_INTERVENTION: False}
+
+ def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
+ """Execute one environment step with given action."""
+ joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())}
+
+ self.robot.send_action(joint_targets_dict)
+
+ obs = self._get_observation()
+
+ self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
+
+ if self.display_cameras:
+ self.render()
+
+ self.current_step += 1
+
+ reward = 0.0
+ terminated = False
+ truncated = False
+
+ return (
+ obs,
+ reward,
+ terminated,
+ truncated,
+ {TeleopEvents.IS_INTERVENTION: False},
+ )
+
+ def render(self) -> None:
+ """Display robot camera feeds."""
+ import cv2
+
+ current_observation = self._get_observation()
+ if current_observation is not None:
+ image_keys = [key for key in current_observation if "image" in key]
+
+ for key in image_keys:
+ cv2.imshow(key, cv2.cvtColor(current_observation[key].numpy(), cv2.COLOR_RGB2BGR))
+ cv2.waitKey(1)
+
+ def close(self) -> None:
+ """Close environment and disconnect robot."""
+ if self.robot.is_connected:
+ self.robot.disconnect()
+
+ def get_raw_joint_positions(self) -> dict[str, float]:
+ """Get raw joint positions."""
+ return self._raw_joint_positions
+
+
+def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
+ """Create robot environment from configuration.
+
+ Args:
+ cfg: Environment configuration.
+
+ Returns:
+ Tuple of (gym environment, teleoperator device).
+ """
+ # Check if this is a GymHIL simulation environment
+ if cfg.name == "gym_hil":
+ assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
+ import gym_hil # noqa: F401
+
+ # Extract gripper settings with defaults
+ use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
+ gripper_penalty = cfg.processor.gripper.gripper_penalty if cfg.processor.gripper is not None else 0.0
+
+ env = gym.make(
+ f"gym_hil/{cfg.task}",
+ image_obs=True,
+ render_mode="human",
+ use_gripper=use_gripper,
+ gripper_penalty=gripper_penalty,
+ )
+
+ return env, None
+
+ # Real robot environment
+ assert cfg.robot is not None, "Robot config must be provided for real robot environment"
+ assert cfg.teleop is not None, "Teleop config must be provided for real robot environment"
+
+ robot = make_robot_from_config(cfg.robot)
+ teleop_device = make_teleoperator_from_config(cfg.teleop)
+ teleop_device.connect()
+
+ # Create base environment with safe defaults
+ use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
+ display_cameras = (
+ cfg.processor.observation.display_cameras if cfg.processor.observation is not None else False
+ )
+ reset_pose = cfg.processor.reset.fixed_reset_joint_positions if cfg.processor.reset is not None else None
+
+ env = RobotEnv(
+ robot=robot,
+ use_gripper=use_gripper,
+ display_cameras=display_cameras,
+ reset_pose=reset_pose,
+ )
+
+ return env, teleop_device
+
+
+def make_processors(
+ env: gym.Env, teleop_device: Teleoperator | None, cfg: HILSerlRobotEnvConfig, device: str = "cpu"
+) -> tuple[
+ DataProcessorPipeline[EnvTransition, EnvTransition], DataProcessorPipeline[EnvTransition, EnvTransition]
+]:
+ """Create environment and action processors.
+
+ Args:
+ env: Robot environment instance.
+ teleop_device: Teleoperator device for intervention.
+ cfg: Processor configuration.
+ device: Target device for computations.
+
+ Returns:
+ Tuple of (environment processor, action processor).
+ """
+ terminate_on_success = (
+ cfg.processor.reset.terminate_on_success if cfg.processor.reset is not None else True
+ )
+
+ if cfg.name == "gym_hil":
+ action_pipeline_steps = [
+ InterventionActionProcessorStep(terminate_on_success=terminate_on_success),
+ Torch2NumpyActionProcessorStep(),
+ ]
+
+ env_pipeline_steps = [
+ Numpy2TorchActionProcessorStep(),
+ VanillaObservationProcessorStep(),
+ AddBatchDimensionProcessorStep(),
+ DeviceProcessorStep(device=device),
+ ]
+
+ return DataProcessorPipeline(
+ steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
+ ), DataProcessorPipeline(
+ steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
+ )
+
+ # Full processor pipeline for real robot environment
+ # Get robot and motor information for kinematics
+ motor_names = list(env.robot.bus.motors.keys())
+
+ # Set up kinematics solver if inverse kinematics is configured
+ kinematics_solver = None
+ if cfg.processor.inverse_kinematics is not None:
+ kinematics_solver = RobotKinematics(
+ urdf_path=cfg.processor.inverse_kinematics.urdf_path,
+ target_frame_name=cfg.processor.inverse_kinematics.target_frame_name,
+ joint_names=motor_names,
+ )
+
+ env_pipeline_steps = [VanillaObservationProcessorStep()]
+
+ if cfg.processor.observation is not None:
+ if cfg.processor.observation.add_joint_velocity_to_observation:
+ env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps))
+ if cfg.processor.observation.add_current_to_observation:
+ env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
+
+ if kinematics_solver is not None:
+ env_pipeline_steps.append(
+ ForwardKinematicsJointsToEEObservation(
+ kinematics=kinematics_solver,
+ motor_names=motor_names,
+ )
+ )
+
+ if cfg.processor.image_preprocessing is not None:
+ env_pipeline_steps.append(
+ ImageCropResizeProcessorStep(
+ crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict,
+ resize_size=cfg.processor.image_preprocessing.resize_size,
+ )
+ )
+
+ # Add time limit processor if reset config exists
+ if cfg.processor.reset is not None:
+ env_pipeline_steps.append(
+ TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
+ )
+
+ # Add gripper penalty processor if gripper config exists and enabled
+ if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
+ env_pipeline_steps.append(
+ GripperPenaltyProcessorStep(
+ penalty=cfg.processor.gripper.gripper_penalty,
+ max_gripper_pos=cfg.processor.max_gripper_pos,
+ )
+ )
+
+ if (
+ cfg.processor.reward_classifier is not None
+ and cfg.processor.reward_classifier.pretrained_path is not None
+ ):
+ env_pipeline_steps.append(
+ RewardClassifierProcessorStep(
+ pretrained_path=cfg.processor.reward_classifier.pretrained_path,
+ device=device,
+ success_threshold=cfg.processor.reward_classifier.success_threshold,
+ success_reward=cfg.processor.reward_classifier.success_reward,
+ terminate_on_success=terminate_on_success,
+ )
+ )
+
+ env_pipeline_steps.append(AddBatchDimensionProcessorStep())
+ env_pipeline_steps.append(DeviceProcessorStep(device=device))
+
+ action_pipeline_steps = [
+ AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device),
+ AddTeleopEventsAsInfoStep(teleop_device=teleop_device),
+ InterventionActionProcessorStep(
+ use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False,
+ terminate_on_success=terminate_on_success,
+ ),
+ ]
+
+ # Replace InverseKinematicsProcessor with new kinematic processors
+ if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
+ # Add EE bounds and safety processor
+ inverse_kinematics_steps = [
+ MapTensorToDeltaActionDictStep(
+ use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
+ ),
+ MapDeltaActionToRobotActionStep(),
+ EEReferenceAndDelta(
+ kinematics=kinematics_solver,
+ end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes,
+ motor_names=motor_names,
+ use_latched_reference=False,
+ use_ik_solution=True,
+ ),
+ EEBoundsAndSafety(
+ end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds,
+ ),
+ GripperVelocityToJoint(
+ clip_max=cfg.processor.max_gripper_pos,
+ speed_factor=1.0,
+ discrete_gripper=True,
+ ),
+ InverseKinematicsRLStep(
+ kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False
+ ),
+ ]
+ action_pipeline_steps.extend(inverse_kinematics_steps)
+ action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names))
+
+ return DataProcessorPipeline(
+ steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
+ ), DataProcessorPipeline(
+ steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
+ )
+
+
+def step_env_and_process_transition(
+ env: gym.Env,
+ transition: EnvTransition,
+ action: torch.Tensor,
+ env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+ action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+) -> EnvTransition:
+ """
+ Execute one step with processor pipeline.
+
+ Args:
+ env: The robot environment
+ transition: Current transition state
+ action: Action to execute
+ env_processor: Environment processor
+ action_processor: Action processor
+
+ Returns:
+ Processed transition with updated state.
+ """
+
+ # Create action transition
+ transition[TransitionKey.ACTION] = action
+ transition[TransitionKey.OBSERVATION] = (
+ env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}
+ )
+ processed_action_transition = action_processor(transition)
+ processed_action = processed_action_transition[TransitionKey.ACTION]
+
+ obs, reward, terminated, truncated, info = env.step(processed_action)
+
+ reward = reward + processed_action_transition[TransitionKey.REWARD]
+ terminated = terminated or processed_action_transition[TransitionKey.DONE]
+ truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
+ complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
+ new_info = processed_action_transition[TransitionKey.INFO].copy()
+ new_info.update(info)
+
+ new_transition = create_transition(
+ observation=obs,
+ action=processed_action,
+ reward=reward,
+ done=terminated,
+ truncated=truncated,
+ info=new_info,
+ complementary_data=complementary_data,
+ )
+ new_transition = env_processor(new_transition)
+
+ return new_transition
+
+
+def control_loop(
+ env: gym.Env,
+ env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+ action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+ teleop_device: Teleoperator,
+ cfg: GymManipulatorConfig,
+) -> None:
+ """Main control loop for robot environment interaction.
+ if cfg.mode == "record": then a dataset will be created and recorded
+
+ Args:
+ env: The robot environment
+ env_processor: Environment processor
+ action_processor: Action processor
+ teleop_device: Teleoperator device
+ cfg: gym_manipulator configuration
+ """
+ dt = 1.0 / cfg.env.fps
+
+ print(f"Starting control loop at {cfg.env.fps} FPS")
+ print("Controls:")
+ print("- Use gamepad/teleop device for intervention")
+ print("- When not intervening, robot will stay still")
+ print("- Press Ctrl+C to exit")
+
+ # Reset environment and processors
+ obs, info = env.reset()
+ complementary_data = (
+ {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
+ )
+ env_processor.reset()
+ action_processor.reset()
+
+ # Process initial observation
+ transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
+ transition = env_processor(data=transition)
+
+ # Determine if gripper is used
+ use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
+
+ dataset = None
+ if cfg.mode == "record":
+ action_features = teleop_device.action_features
+ features = {
+ ACTION: action_features,
+ REWARD: {"dtype": "float32", "shape": (1,), "names": None},
+ DONE: {"dtype": "bool", "shape": (1,), "names": None},
+ }
+ if use_gripper:
+ features["complementary_info.discrete_penalty"] = {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": ["discrete_penalty"],
+ }
+
+ for key, value in transition[TransitionKey.OBSERVATION].items():
+ if key == OBS_STATE:
+ features[key] = {
+ "dtype": "float32",
+ "shape": value.squeeze(0).shape,
+ "names": None,
+ }
+ if "image" in key:
+ features[key] = {
+ "dtype": "video",
+ "shape": value.squeeze(0).shape,
+ "names": ["channels", "height", "width"],
+ }
+
+ # Create dataset
+ dataset = LeRobotDataset.create(
+ cfg.dataset.repo_id,
+ cfg.env.fps,
+ root=cfg.dataset.root,
+ use_videos=True,
+ image_writer_threads=4,
+ image_writer_processes=0,
+ features=features,
+ )
+
+ episode_idx = 0
+ episode_step = 0
+ episode_start_time = time.perf_counter()
+
+ while episode_idx < cfg.dataset.num_episodes_to_record:
+ step_start_time = time.perf_counter()
+
+ # Create a neutral action (no movement)
+ neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
+ if use_gripper:
+ neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
+
+ # Use the new step function
+ transition = step_env_and_process_transition(
+ env=env,
+ transition=transition,
+ action=neutral_action,
+ env_processor=env_processor,
+ action_processor=action_processor,
+ )
+ terminated = transition.get(TransitionKey.DONE, False)
+ truncated = transition.get(TransitionKey.TRUNCATED, False)
+
+ if cfg.mode == "record":
+ observations = {
+ k: v.squeeze(0).cpu()
+ for k, v in transition[TransitionKey.OBSERVATION].items()
+ if isinstance(v, torch.Tensor)
+ }
+ # Use teleop_action if available, otherwise use the action from the transition
+ action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
+ "teleop_action", transition[TransitionKey.ACTION]
+ )
+ frame = {
+ **observations,
+ ACTION: action_to_record.cpu(),
+ REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
+ DONE: np.array([terminated or truncated], dtype=bool),
+ }
+ if use_gripper:
+ discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
+ frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
+
+ if dataset is not None:
+ frame["task"] = cfg.dataset.task
+ dataset.add_frame(frame)
+
+ episode_step += 1
+
+ # Handle episode termination
+ if terminated or truncated:
+ episode_time = time.perf_counter() - episode_start_time
+ logging.info(
+ f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
+ )
+ episode_step = 0
+ episode_idx += 1
+
+ if dataset is not None:
+ if transition[TransitionKey.INFO].get("rerecord_episode", False):
+ logging.info(f"Re-recording episode {episode_idx}")
+ dataset.clear_episode_buffer()
+ episode_idx -= 1
+ else:
+ logging.info(f"Saving episode {episode_idx}")
+ dataset.save_episode()
+
+ # Reset for new episode
+ obs, info = env.reset()
+ env_processor.reset()
+ action_processor.reset()
+
+ transition = create_transition(observation=obs, info=info)
+ transition = env_processor(transition)
+
+ # Maintain fps timing
+ busy_wait(dt - (time.perf_counter() - step_start_time))
+
+ if dataset is not None and cfg.dataset.push_to_hub:
+ logging.info("Pushing dataset to hub")
+ dataset.push_to_hub()
+
+
+def replay_trajectory(
+ env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig
+) -> None:
+ """Replay recorded trajectory on robot environment."""
+ assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
+
+ dataset = LeRobotDataset(
+ cfg.dataset.repo_id,
+ root=cfg.dataset.root,
+ episodes=[cfg.dataset.replay_episode],
+ download_videos=False,
+ )
+ episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
+ actions = episode_frames.select_columns(ACTION)
+
+ _, info = env.reset()
+
+ for action_data in actions:
+ start_time = time.perf_counter()
+ transition = create_transition(
+ observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
+ action=action_data[ACTION],
+ )
+ transition = action_processor(transition)
+ env.step(transition[TransitionKey.ACTION])
+ busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
+
+
+@parser.wrap()
+def main(cfg: GymManipulatorConfig) -> None:
+ """Main entry point for gym manipulator script."""
+ env, teleop_device = make_robot_env(cfg.env)
+ env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device)
+
+ print("Environment observation space:", env.observation_space)
+ print("Environment action space:", env.action_space)
+ print("Environment processor:", env_processor)
+ print("Action processor:", action_processor)
+
+ if cfg.mode == "replay":
+ replay_trajectory(env, action_processor, cfg)
+ exit()
+
+ control_loop(env, env_processor, action_processor, teleop_device, cfg)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/rl/learner.py
similarity index 95%
rename from src/lerobot/scripts/rl/learner.py
rename to src/lerobot/rl/learner.py
index d8830d83e..d9758d3a3 100644
--- a/src/lerobot/scripts/rl/learner.py
+++ b/src/lerobot/rl/learner.py
@@ -25,7 +25,7 @@ Examples of usage:
- Start a learner server for training:
```bash
-python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
+python -m lerobot.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
@@ -62,50 +62,45 @@ from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
-from lerobot.constants import (
+from lerobot.datasets.factory import make_dataset
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.policies.factory import make_policy
+from lerobot.policies.sac.modeling_sac import SACPolicy
+from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
+from lerobot.rl.process import ProcessSignalHandler
+from lerobot.rl.wandb_utils import WandBLogger
+from lerobot.robots import so100_follower # noqa: F401
+from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
+from lerobot.teleoperators.utils import TeleopEvents
+from lerobot.transport import services_pb2_grpc
+from lerobot.transport.utils import (
+ MAX_MESSAGE_SIZE,
+ bytes_to_python_object,
+ bytes_to_transitions,
+ state_to_bytes,
+)
+from lerobot.utils.constants import (
+ ACTION,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
)
-from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy
-from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.robots import so100_follower # noqa: F401
-from lerobot.scripts.rl import learner_service
-from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
-from lerobot.transport import services_pb2_grpc
-from lerobot.transport.utils import (
- bytes_to_python_object,
- bytes_to_transitions,
- state_to_bytes,
-)
-from lerobot.utils.buffer import ReplayBuffer, concatenate_batch_transitions
-from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
+ load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
-from lerobot.utils.train_utils import (
- load_training_state as utils_load_training_state,
-)
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
init_logging,
)
-from lerobot.utils.wandb_utils import WandBLogger
-LOG_PREFIX = "[LEARNER]"
-
-
-#################################################
-# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
-#################################################
+from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
@parser.wrap()
@@ -157,7 +152,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
- from lerobot.utils.wandb_utils import WandBLogger
+ from lerobot.rl.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
@@ -250,9 +245,7 @@ def start_learner_threads(
logging.info("[LEARNER] queues closed")
-#################################################
-# Core algorithm functions #
-#################################################
+# Core algorithm functions
def add_actor_information_and_train(
@@ -408,7 +401,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
- actions = batch["action"]
+ actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -421,7 +414,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
- "action": actions,
+ ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -466,7 +459,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
- actions = batch["action"]
+ actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -480,7 +473,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
- "action": actions,
+ ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -646,7 +639,7 @@ def start_learner(
# TODO: Check if its useful
_ = ProcessSignalHandler(False, display_pid=True)
- service = learner_service.LearnerService(
+ service = LearnerService(
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency,
@@ -656,10 +649,10 @@ def start_learner(
)
server = grpc.server(
- ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
+ ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
- ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
- ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+ ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
+ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
],
)
@@ -677,7 +670,7 @@ def start_learner(
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
- server.stop(learner_service.SHUTDOWN_TIMEOUT)
+ server.stop(SHUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
return optimizers, lr_scheduler
-#################################################
-# Training setup functions #
-#################################################
+# Training setup functions
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
return offline_replay_buffer
-#################################################
-# Utilities/Helpers functions #
-#################################################
+# Utilities/Helpers functions
def get_observation_features(
@@ -1049,10 +1038,8 @@ def get_observation_features(
return None, None
with torch.no_grad():
- observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
- next_observation_features = policy.actor.encoder.get_cached_image_features(
- next_observations, normalize=True
- )
+ observation_features = policy.actor.encoder.get_cached_image_features(observations)
+ next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
@@ -1109,8 +1096,18 @@ def check_nan_in_transition(
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
- state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
- state_bytes = state_to_bytes(state_dict)
+
+ # Create a dictionary to hold all the state dicts
+ state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
+
+ # Add discrete critic if it exists
+ if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
+ state_dicts["discrete_critic"] = move_state_dict_to_device(
+ policy.discrete_critic.state_dict(), device="cpu"
+ )
+ logging.debug("[LEARNER] Including discrete critic in state dict push")
+
+ state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
@@ -1157,7 +1154,7 @@ def process_transitions(
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
- actions=transition["action"],
+ actions=transition[ACTION],
next_state=transition["next_state"],
):
logging.warning("[LEARNER] NaN detected in transition, skipping")
@@ -1167,7 +1164,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
- "is_intervention"
+ TeleopEvents.IS_INTERVENTION
):
offline_replay_buffer.add(**transition)
diff --git a/src/lerobot/scripts/rl/learner_service.py b/src/lerobot/rl/learner_service.py
similarity index 97%
rename from src/lerobot/scripts/rl/learner_service.py
rename to src/lerobot/rl/learner_service.py
index 198e52945..7ef38119b 100644
--- a/src/lerobot/scripts/rl/learner_service.py
+++ b/src/lerobot/rl/learner_service.py
@@ -19,11 +19,10 @@ import logging
import time
from multiprocessing import Event, Queue
+from lerobot.rl.queue import get_last_item_from_queue
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
-from lerobot.utils.queue import get_last_item_from_queue
-MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
diff --git a/src/lerobot/utils/process.py b/src/lerobot/rl/process.py
similarity index 100%
rename from src/lerobot/utils/process.py
rename to src/lerobot/rl/process.py
diff --git a/src/lerobot/utils/queue.py b/src/lerobot/rl/queue.py
similarity index 64%
rename from src/lerobot/utils/queue.py
rename to src/lerobot/rl/queue.py
index ceb30e2bf..864d798ac 100644
--- a/src/lerobot/utils/queue.py
+++ b/src/lerobot/rl/queue.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import platform
+from contextlib import suppress
from queue import Empty
from typing import Any
@@ -30,10 +32,21 @@ def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) ->
item = None
# Drain queue and keep only the most recent parameters
- try:
- while True:
+ if platform.system() == "Darwin":
+ # On Mac, avoid using `qsize` due to unreliable implementation.
+ # There is a comment on `qsize` code in the Python source:
+ # Raises NotImplementedError on Mac OSX because of broken sem_getvalue()
+ try:
+ while True:
+ item = queue.get_nowait()
+ except Empty:
+ pass
+
+ return item
+
+ # Details about using qsize in https://github.com/huggingface/lerobot/issues/1523
+ while queue.qsize() > 0:
+ with suppress(Empty):
item = queue.get_nowait()
- except Empty:
- pass
return item
diff --git a/src/lerobot/utils/wandb_utils.py b/src/lerobot/rl/wandb_utils.py
similarity index 98%
rename from src/lerobot/utils/wandb_utils.py
rename to src/lerobot/rl/wandb_utils.py
index 91b4ec95c..01cef9487 100644
--- a/src/lerobot/utils/wandb_utils.py
+++ b/src/lerobot/rl/wandb_utils.py
@@ -23,7 +23,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.constants import PRETRAINED_MODEL_DIR
+from lerobot.utils.constants import PRETRAINED_MODEL_DIR
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
@@ -137,7 +137,7 @@ class WandBLogger:
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
- if not isinstance(v, (int, float, str)):
+ if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
diff --git a/src/lerobot/robots/__init__.py b/src/lerobot/robots/__init__.py
index d8fd0de93..1dba0f1b0 100644
--- a/src/lerobot/robots/__init__.py
+++ b/src/lerobot/robots/__init__.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config import RobotConfig
from .robot import Robot
from .utils import make_robot_from_config
diff --git a/src/lerobot/robots/bi_so100_follower/__init__.py b/src/lerobot/robots/bi_so100_follower/__init__.py
new file mode 100644
index 000000000..90f56516b
--- /dev/null
+++ b/src/lerobot/robots/bi_so100_follower/__init__.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .bi_so100_follower import BiSO100Follower
+from .config_bi_so100_follower import BiSO100FollowerConfig
diff --git a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py
new file mode 100644
index 000000000..7992b79fd
--- /dev/null
+++ b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from functools import cached_property
+from typing import Any
+
+from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.robots.so100_follower import SO100Follower
+from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
+
+from ..robot import Robot
+from .config_bi_so100_follower import BiSO100FollowerConfig
+
+logger = logging.getLogger(__name__)
+
+
+class BiSO100Follower(Robot):
+ """
+ [Bimanual SO-100 Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
+ This bimanual robot can also be easily adapted to use SO-101 follower arms, just replace the SO100Follower class with SO101Follower and SO100FollowerConfig with SO101FollowerConfig.
+ """
+
+ config_class = BiSO100FollowerConfig
+ name = "bi_so100_follower"
+
+ def __init__(self, config: BiSO100FollowerConfig):
+ super().__init__(config)
+ self.config = config
+
+ left_arm_config = SO100FollowerConfig(
+ id=f"{config.id}_left" if config.id else None,
+ calibration_dir=config.calibration_dir,
+ port=config.left_arm_port,
+ disable_torque_on_disconnect=config.left_arm_disable_torque_on_disconnect,
+ max_relative_target=config.left_arm_max_relative_target,
+ use_degrees=config.left_arm_use_degrees,
+ cameras={},
+ )
+
+ right_arm_config = SO100FollowerConfig(
+ id=f"{config.id}_right" if config.id else None,
+ calibration_dir=config.calibration_dir,
+ port=config.right_arm_port,
+ disable_torque_on_disconnect=config.right_arm_disable_torque_on_disconnect,
+ max_relative_target=config.right_arm_max_relative_target,
+ use_degrees=config.right_arm_use_degrees,
+ cameras={},
+ )
+
+ self.left_arm = SO100Follower(left_arm_config)
+ self.right_arm = SO100Follower(right_arm_config)
+ self.cameras = make_cameras_from_configs(config.cameras)
+
+ @property
+ def _motors_ft(self) -> dict[str, type]:
+ return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | {
+ f"right_{motor}.pos": float for motor in self.right_arm.bus.motors
+ }
+
+ @property
+ def _cameras_ft(self) -> dict[str, tuple]:
+ return {
+ cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
+ }
+
+ @cached_property
+ def observation_features(self) -> dict[str, type | tuple]:
+ return {**self._motors_ft, **self._cameras_ft}
+
+ @cached_property
+ def action_features(self) -> dict[str, type]:
+ return self._motors_ft
+
+ @property
+ def is_connected(self) -> bool:
+ return (
+ self.left_arm.bus.is_connected
+ and self.right_arm.bus.is_connected
+ and all(cam.is_connected for cam in self.cameras.values())
+ )
+
+ def connect(self, calibrate: bool = True) -> None:
+ self.left_arm.connect(calibrate)
+ self.right_arm.connect(calibrate)
+
+ for cam in self.cameras.values():
+ cam.connect()
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.left_arm.is_calibrated and self.right_arm.is_calibrated
+
+ def calibrate(self) -> None:
+ self.left_arm.calibrate()
+ self.right_arm.calibrate()
+
+ def configure(self) -> None:
+ self.left_arm.configure()
+ self.right_arm.configure()
+
+ def setup_motors(self) -> None:
+ self.left_arm.setup_motors()
+ self.right_arm.setup_motors()
+
+ def get_observation(self) -> dict[str, Any]:
+ obs_dict = {}
+
+ # Add "left_" prefix
+ left_obs = self.left_arm.get_observation()
+ obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
+
+ # Add "right_" prefix
+ right_obs = self.right_arm.get_observation()
+ obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
+
+ for cam_key, cam in self.cameras.items():
+ start = time.perf_counter()
+ obs_dict[cam_key] = cam.async_read()
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+
+ return obs_dict
+
+ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
+ # Remove "left_" prefix
+ left_action = {
+ key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
+ }
+ # Remove "right_" prefix
+ right_action = {
+ key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
+ }
+
+ send_action_left = self.left_arm.send_action(left_action)
+ send_action_right = self.right_arm.send_action(right_action)
+
+ # Add prefixes back
+ prefixed_send_action_left = {f"left_{key}": value for key, value in send_action_left.items()}
+ prefixed_send_action_right = {f"right_{key}": value for key, value in send_action_right.items()}
+
+ return {**prefixed_send_action_left, **prefixed_send_action_right}
+
+ def disconnect(self):
+ self.left_arm.disconnect()
+ self.right_arm.disconnect()
+
+ for cam in self.cameras.values():
+ cam.disconnect()
diff --git a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py
new file mode 100644
index 000000000..5806d7415
--- /dev/null
+++ b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.cameras import CameraConfig
+
+from ..config import RobotConfig
+
+
+@RobotConfig.register_subclass("bi_so100_follower")
+@dataclass
+class BiSO100FollowerConfig(RobotConfig):
+ left_arm_port: str
+ right_arm_port: str
+
+ # Optional
+ left_arm_disable_torque_on_disconnect: bool = True
+ left_arm_max_relative_target: float | dict[str, float] | None = None
+ left_arm_use_degrees: bool = False
+ right_arm_disable_torque_on_disconnect: bool = True
+ right_arm_max_relative_target: float | dict[str, float] | None = None
+ right_arm_use_degrees: bool = False
+
+ # cameras (shared between both arms)
+ cameras: dict[str, CameraConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/hope_jr/__init__.py b/src/lerobot/robots/hope_jr/__init__.py
new file mode 100644
index 000000000..26603ebb0
--- /dev/null
+++ b/src/lerobot/robots/hope_jr/__init__.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig
+from .hope_jr_arm import HopeJrArm
+from .hope_jr_hand import HopeJrHand
diff --git a/src/lerobot/robots/hope_jr/config_hope_jr.py b/src/lerobot/robots/hope_jr/config_hope_jr.py
new file mode 100644
index 000000000..f2af5f47c
--- /dev/null
+++ b/src/lerobot/robots/hope_jr/config_hope_jr.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.cameras import CameraConfig
+
+from ..config import RobotConfig
+
+
+@RobotConfig.register_subclass("hope_jr_hand")
+@dataclass
+class HopeJrHandConfig(RobotConfig):
+ port: str # Port to connect to the hand
+ side: str # "left" / "right"
+
+ disable_torque_on_disconnect: bool = True
+
+ cameras: dict[str, CameraConfig] = field(default_factory=dict)
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.side not in ["right", "left"]:
+ raise ValueError(self.side)
+
+
+@RobotConfig.register_subclass("hope_jr_arm")
+@dataclass
+class HopeJrArmConfig(RobotConfig):
+ port: str # Port to connect to the hand
+ disable_torque_on_disconnect: bool = True
+
+ # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
+ max_relative_target: float | dict[str, float] | None = None
+
+ cameras: dict[str, CameraConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/hope_jr/hope_jr.mdx b/src/lerobot/robots/hope_jr/hope_jr.mdx
new file mode 120000
index 000000000..a076e4754
--- /dev/null
+++ b/src/lerobot/robots/hope_jr/hope_jr.mdx
@@ -0,0 +1 @@
+../../../../docs/source/hope_jr.mdx
\ No newline at end of file
diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py
new file mode 100644
index 000000000..220a29f8c
--- /dev/null
+++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from functools import cached_property
+from typing import Any
+
+from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.motors import Motor, MotorNormMode
+from lerobot.motors.calibration_gui import RangeFinderGUI
+from lerobot.motors.feetech import (
+ FeetechMotorsBus,
+)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+
+from ..robot import Robot
+from ..utils import ensure_safe_goal_position
+from .config_hope_jr import HopeJrArmConfig
+
+logger = logging.getLogger(__name__)
+
+
+class HopeJrArm(Robot):
+ config_class = HopeJrArmConfig
+ name = "hope_jr_arm"
+
+ def __init__(self, config: HopeJrArmConfig):
+ super().__init__(config)
+ self.config = config
+ self.bus = FeetechMotorsBus(
+ port=self.config.port,
+ motors={
+ "shoulder_pitch": Motor(1, "sm8512bl", MotorNormMode.RANGE_M100_100),
+ "shoulder_yaw": Motor(2, "sts3250", MotorNormMode.RANGE_M100_100),
+ "shoulder_roll": Motor(3, "sts3250", MotorNormMode.RANGE_M100_100),
+ "elbow_flex": Motor(4, "sts3250", MotorNormMode.RANGE_M100_100),
+ "wrist_roll": Motor(5, "sts3250", MotorNormMode.RANGE_M100_100),
+ "wrist_yaw": Motor(6, "sts3250", MotorNormMode.RANGE_M100_100),
+ "wrist_pitch": Motor(7, "sts3250", MotorNormMode.RANGE_M100_100),
+ },
+ calibration=self.calibration,
+ )
+ self.cameras = make_cameras_from_configs(config.cameras)
+
+ # HACK
+ self.shoulder_pitch = "shoulder_pitch"
+ self.other_motors = [m for m in self.bus.motors if m != "shoulder_pitch"]
+
+ @property
+ def _motors_ft(self) -> dict[str, type]:
+ return {f"{motor}.pos": float for motor in self.bus.motors}
+
+ @property
+ def _cameras_ft(self) -> dict[str, tuple]:
+ return {
+ cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
+ }
+
+ @cached_property
+ def observation_features(self) -> dict[str, type | tuple]:
+ return {**self._motors_ft, **self._cameras_ft}
+
+ @cached_property
+ def action_features(self) -> dict[str, type]:
+ return self._motors_ft
+
+ @property
+ def is_connected(self) -> bool:
+ return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
+
+ def connect(self, calibrate: bool = True) -> None:
+ """
+ We assume that at connection time, arm is in a rest position,
+ and torque can be safely disabled to run calibration.
+ """
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ self.bus.connect(handshake=False)
+ if not self.is_calibrated and calibrate:
+ self.calibrate()
+
+ # Connect the cameras
+ for cam in self.cameras.values():
+ cam.connect()
+
+ self.configure()
+ logger.info(f"{self} connected.")
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.bus.is_calibrated
+
+ def calibrate(self) -> None:
+ groups = {
+ "all": list(self.bus.motors.keys()),
+ "shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"],
+ "elbow": ["elbow_flex"],
+ "wrist": ["wrist_roll", "wrist_yaw", "wrist_pitch"],
+ }
+
+ self.calibration = RangeFinderGUI(self.bus, groups).run()
+ self._save_calibration()
+ print("Calibration saved to", self.calibration_fpath)
+
+ def configure(self) -> None:
+ with self.bus.torque_disabled():
+ self.bus.configure_motors(maximum_acceleration=30, acceleration=30)
+
+ def setup_motors(self) -> None:
+ # TODO: add docstring
+ for motor in reversed(self.bus.motors):
+ input(f"Connect the controller board to the '{motor}' motor only and press enter.")
+ self.bus.setup_motor(motor)
+ print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
+
+ def get_observation(self) -> dict[str, Any]:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ # Read arm position
+ start = time.perf_counter()
+ obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
+ obs_dict[self.shoulder_pitch] = self.bus.read("Present_Position", self.shoulder_pitch)
+ obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read state: {dt_ms:.1f}ms")
+
+ # Capture images from cameras
+ for cam_key, cam in self.cameras.items():
+ start = time.perf_counter()
+ obs_dict[cam_key] = cam.async_read()
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+
+ return obs_dict
+
+ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
+
+ # Cap goal position when too far away from present position.
+ # /!\ Slower fps expected due to reading from the follower.
+ if self.config.max_relative_target is not None:
+ present_pos = self.bus.sync_read("Present_Position")
+ goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
+ goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
+
+ self.bus.sync_write("Goal_Position", goal_pos)
+ return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+
+ def disconnect(self):
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ self.bus.disconnect(self.config.disable_torque_on_disconnect)
+ for cam in self.cameras.values():
+ cam.disconnect()
+
+ logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py
new file mode 100644
index 000000000..9e960642b
--- /dev/null
+++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from functools import cached_property
+from typing import Any
+
+from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.motors import Motor, MotorNormMode
+from lerobot.motors.calibration_gui import RangeFinderGUI
+from lerobot.motors.feetech import (
+ FeetechMotorsBus,
+)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+
+from ..robot import Robot
+from .config_hope_jr import HopeJrHandConfig
+
+logger = logging.getLogger(__name__)
+
+RIGHT_HAND_INVERSIONS = [
+ "thumb_mcp",
+ "thumb_dip",
+ "index_ulnar_flexor",
+ "middle_ulnar_flexor",
+ "ring_ulnar_flexor",
+ "ring_pip_dip",
+ "pinky_ulnar_flexor",
+ "pinky_pip_dip",
+]
+
+LEFT_HAND_INVERSIONS = [
+ "thumb_cmc",
+ "thumb_mcp",
+ "thumb_dip",
+ "index_radial_flexor",
+ "index_pip_dip",
+ "middle_radial_flexor",
+ "middle_pip_dip",
+ "ring_radial_flexor",
+ "ring_pip_dip",
+ "pinky_radial_flexor",
+ # "pinky_pip_dip",
+]
+
+
+class HopeJrHand(Robot):
+ config_class = HopeJrHandConfig
+ name = "hope_jr_hand"
+
+ def __init__(self, config: HopeJrHandConfig):
+ super().__init__(config)
+ self.config = config
+ self.bus = FeetechMotorsBus(
+ port=self.config.port,
+ motors={
+ # Thumb
+ "thumb_cmc": Motor(1, "scs0009", MotorNormMode.RANGE_0_100),
+ "thumb_mcp": Motor(2, "scs0009", MotorNormMode.RANGE_0_100),
+ "thumb_pip": Motor(3, "scs0009", MotorNormMode.RANGE_0_100),
+ "thumb_dip": Motor(4, "scs0009", MotorNormMode.RANGE_0_100),
+ # Index
+ "index_radial_flexor": Motor(5, "scs0009", MotorNormMode.RANGE_0_100),
+ "index_ulnar_flexor": Motor(6, "scs0009", MotorNormMode.RANGE_0_100),
+ "index_pip_dip": Motor(7, "scs0009", MotorNormMode.RANGE_0_100),
+ # Middle
+ "middle_radial_flexor": Motor(8, "scs0009", MotorNormMode.RANGE_0_100),
+ "middle_ulnar_flexor": Motor(9, "scs0009", MotorNormMode.RANGE_0_100),
+ "middle_pip_dip": Motor(10, "scs0009", MotorNormMode.RANGE_0_100),
+ # Ring
+ "ring_radial_flexor": Motor(11, "scs0009", MotorNormMode.RANGE_0_100),
+ "ring_ulnar_flexor": Motor(12, "scs0009", MotorNormMode.RANGE_0_100),
+ "ring_pip_dip": Motor(13, "scs0009", MotorNormMode.RANGE_0_100),
+ # Pinky
+ "pinky_radial_flexor": Motor(14, "scs0009", MotorNormMode.RANGE_0_100),
+ "pinky_ulnar_flexor": Motor(15, "scs0009", MotorNormMode.RANGE_0_100),
+ "pinky_pip_dip": Motor(16, "scs0009", MotorNormMode.RANGE_0_100),
+ },
+ calibration=self.calibration,
+ protocol_version=1,
+ )
+ self.cameras = make_cameras_from_configs(config.cameras)
+ self.inverted_motors = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS
+
+ @property
+ def _motors_ft(self) -> dict[str, type]:
+ return {f"{motor}.pos": float for motor in self.bus.motors}
+
+ @property
+ def _cameras_ft(self) -> dict[str, tuple]:
+ return {
+ cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
+ }
+
+ @cached_property
+ def observation_features(self) -> dict[str, type | tuple]:
+ return {**self._motors_ft, **self._cameras_ft}
+
+ @cached_property
+ def action_features(self) -> dict[str, type]:
+ return self._motors_ft
+
+ @property
+ def is_connected(self) -> bool:
+ return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
+
+ def connect(self, calibrate: bool = True) -> None:
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ self.bus.connect()
+ if not self.is_calibrated and calibrate:
+ self.calibrate()
+
+ # Connect the cameras
+ for cam in self.cameras.values():
+ cam.connect()
+
+ self.configure()
+ logger.info(f"{self} connected.")
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.bus.is_calibrated
+
+ def calibrate(self) -> None:
+ fingers = {}
+ for finger in ["thumb", "index", "middle", "ring", "pinky"]:
+ fingers[finger] = [motor for motor in self.bus.motors if motor.startswith(finger)]
+
+ self.calibration = RangeFinderGUI(self.bus, fingers).run()
+ for motor in self.inverted_motors:
+ self.calibration[motor].drive_mode = 1
+ self._save_calibration()
+ print("Calibration saved to", self.calibration_fpath)
+
+ def configure(self) -> None:
+ with self.bus.torque_disabled():
+ self.bus.configure_motors()
+
+ def setup_motors(self) -> None:
+ # TODO: add docstring
+ for motor in self.bus.motors:
+ input(f"Connect the controller board to the '{motor}' motor only and press enter.")
+ self.bus.setup_motor(motor)
+ print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
+
+ def get_observation(self) -> dict[str, Any]:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ obs_dict = {}
+
+ # Read hand position
+ start = time.perf_counter()
+ for motor in self.bus.motors:
+ obs_dict[f"{motor}.pos"] = self.bus.read("Present_Position", motor)
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read state: {dt_ms:.1f}ms")
+
+ # Capture images from cameras
+ for cam_key, cam in self.cameras.items():
+ start = time.perf_counter()
+ obs_dict[cam_key] = cam.async_read()
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
+
+ return obs_dict
+
+ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
+ self.bus.sync_write("Goal_Position", goal_pos)
+ return action
+
+ def disconnect(self):
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ self.bus.disconnect(self.config.disable_torque_on_disconnect)
+ for cam in self.cameras.values():
+ cam.disconnect()
+
+ logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/robots/koch_follower/__init__.py b/src/lerobot/robots/koch_follower/__init__.py
index ae98a2c38..6271c4e55 100644
--- a/src/lerobot/robots/koch_follower/__init__.py
+++ b/src/lerobot/robots/koch_follower/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_koch_follower import KochFollowerConfig
from .koch_follower import KochFollower
diff --git a/src/lerobot/robots/koch_follower/config_koch_follower.py b/src/lerobot/robots/koch_follower/config_koch_follower.py
index a7c9249ae..02a95ef4e 100644
--- a/src/lerobot/robots/koch_follower/config_koch_follower.py
+++ b/src/lerobot/robots/koch_follower/config_koch_follower.py
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
- max_relative_target: int | None = None
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
+ max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/koch_follower/koch.mdx b/src/lerobot/robots/koch_follower/koch.mdx
deleted file mode 100644
index f70a1802c..000000000
--- a/src/lerobot/robots/koch_follower/koch.mdx
+++ /dev/null
@@ -1,258 +0,0 @@
-# Koch v1.1
-
-In the steps below, we explain how to assemble the Koch v1.1 robot.
-
-## Order and assemble the parts
-
-Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below.
-
-For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk).
-
-> [!WARNING]
-> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video:
-> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors).
-> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors).
-
-
-## Install LeRobot 🤗
-
-To install LeRobot follow, our [Installation Guide](./installation)
-
-In addition to these instructions, you need to install the Dynamixel SDK:
-```bash
-pip install -e ".[dynamixel]"
-```
-
-## Configure the motors
-
-### 1. Find the USB ports associated with each arm
-
-To find the port for each bus servo adapter, run this script:
-```bash
-python -m lerobot.find_port
-```
-
-
-
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the USB cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/tty.usbmodem575E0032081
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
-
-
-
-
-On Linux, you might need to give access to the USB ports by running:
-```bash
-sudo chmod 666 /dev/ttyACM0
-sudo chmod 666 /dev/ttyACM1
-```
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/ttyACM0', '/dev/ttyACM1']
-Remove the usb cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/ttyACM1
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
-
-
-
-
-### 2. Set the motors ids and baudrates
-
-Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
-
-To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
-
-If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match.
-
-#### Follower
-
-Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
-
-For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm.
-
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --robot.type=koch_follower \
- --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
-
-config = KochFollowerConfig(
- port="/dev/tty.usbmodem575E0031751",
- id="my_awesome_follower_arm",
-)
-follower = KochFollower(config)
-follower.setup_motors()
-```
-
-
-
-You should see the following instruction.
-```
-Connect the controller board to the 'gripper' motor only and press enter.
-```
-
-As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
-
-
-Troubleshooting
-
- If you get an error at that point, check your cables and make sure they are plugged in properly:
-
-
Power supply
-
USB cable between your computer and the controller board
-
The 3-pin cable from the controller board to the motor
-
-
- If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
-
-
-You should then see the following message:
-```
-'gripper' motor id set to 6
-```
-
-Followed by the next instruction:
-```
-Connect the controller board to the 'wrist_roll' motor only and press enter.
-```
-
-You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
-
-Repeat the operation for each motor as instructed.
-
-> [!TIP]
-> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
-
-When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
-
-#### Leader
-Do the same steps for the leader arm but modify the command or script accordingly.
-
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --teleop.type=koch_leader \
- --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
-
-config = KochLeaderConfig(
- port="/dev/tty.usbmodem575E0031751",
- id="my_awesome_leader_arm",
-)
-leader = KochLeader(config)
-leader.setup_motors()
-```
-
-
-
-## Calibrate
-
-Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
-The calibration process is very important because it allows a neural network trained on one robot to work on another.
-
-#### Follower
-
-Run the following command or API example to calibrate the follower arm:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --robot.type=koch_follower \
- --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower
-
-config = KochFollowerConfig(
- port="/dev/tty.usbmodem585A0076891",
- id="my_awesome_follower_arm",
-)
-
-follower = KochFollower(config)
-follower.connect(calibrate=False)
-follower.calibrate()
-follower.disconnect()
-```
-
-
-
-We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
-
-#### Leader
-
-Do the same steps to calibrate the leader arm, run the following command or API example:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --teleop.type=koch_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader
-
-config = KochLeaderConfig(
- port="/dev/tty.usbmodem575E0031751",
- id="my_awesome_leader_arm",
-)
-
-leader = KochLeader(config)
-leader.connect(calibrate=False)
-leader.calibrate()
-leader.disconnect()
-```
-
-
-
-Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
-
-> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/src/lerobot/robots/koch_follower/koch.mdx b/src/lerobot/robots/koch_follower/koch.mdx
new file mode 120000
index 000000000..ef43feb06
--- /dev/null
+++ b/src/lerobot/robots/koch_follower/koch.mdx
@@ -0,0 +1 @@
+../../../../docs/source/koch.mdx
\ No newline at end of file
diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py
index 1cfc6cf08..41a57828b 100644
--- a/src/lerobot/robots/koch_follower/koch_follower.py
+++ b/src/lerobot/robots/koch_follower/koch_follower.py
@@ -20,12 +20,12 @@ from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -94,6 +94,9 @@ class KochFollower(Robot):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
for cam in self.cameras.values():
@@ -107,8 +110,17 @@ class KochFollower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
- logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+ logger.info(f"\nRunning calibration of {self}")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
diff --git a/src/lerobot/robots/lekiwi/__init__.py b/src/lerobot/robots/lekiwi/__init__.py
index e3d10c5c1..ada2ff368 100644
--- a/src/lerobot/robots/lekiwi/__init__.py
+++ b/src/lerobot/robots/lekiwi/__init__.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig
from .lekiwi import LeKiwi
from .lekiwi_client import LeKiwiClient
diff --git a/src/lerobot/robots/lekiwi/config_lekiwi.py b/src/lerobot/robots/lekiwi/config_lekiwi.py
index f0f8c24b3..acaf5f0ec 100644
--- a/src/lerobot/robots/lekiwi/config_lekiwi.py
+++ b/src/lerobot/robots/lekiwi/config_lekiwi.py
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
- max_relative_target: int | None = None
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
+ max_relative_target: float | dict[str, float] | None = None
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
diff --git a/src/lerobot/robots/lekiwi/lekiwi.mdx b/src/lerobot/robots/lekiwi/lekiwi.mdx
deleted file mode 100644
index 61b1c05c1..000000000
--- a/src/lerobot/robots/lekiwi/lekiwi.mdx
+++ /dev/null
@@ -1,300 +0,0 @@
-# LeKiwi
-
-In the steps below, we explain how to assemble the LeKiwi mobile robot.
-
-## Source the parts
-
-Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts.
-And advise if it's your first time printing or if you don't own a 3D printer.
-
-### Wired version
-If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating.
-
-## Install software on Pi
-Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board.
-
-### Install OS
-For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu.
-
-### Setup SSH
-After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension.
-
-### Install LeRobot on Pi 🤗
-
-On your Raspberry Pi install LeRobot using our [Installation Guide](./installation)
-
-In addition to these instructions, you need to install the Feetech sdk on your Pi:
-```bash
-pip install -e ".[feetech]"
-```
-
-## Install LeRobot locally
-If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi.
-
-Follow our [Installation Guide](./installation)
-
-Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:.
-Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands.
-
-# Step-by-Step Assembly Instructions
-
-First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages:
-
-- [Assemble SO101](./so101#step-by-step-assembly-instructions)
-- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md)
-
-### Find the USB ports associated with motor board
-
-To find the port for each bus servo adapter, run this script:
-```bash
-python -m lerobot.find_port
-```
-
-
-
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/tty.usbmodem575E0032081']
-Remove the USB cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/tty.usbmodem575E0032081
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board.
-
-
-
-
-On Linux, you might need to give access to the USB ports by running:
-```bash
-sudo chmod 666 /dev/ttyACM0
-sudo chmod 666 /dev/ttyACM1
-```
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/ttyACM0']
-Remove the usb cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/ttyACM0
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/ttyACM0` corresponding to your board.
-
-
-
-
-### Configure motors
-The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9.
-
-You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7)
-
-```bash
-python -m lerobot.setup_motors \
- --robot.type=lekiwi \
- --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step
-```
-
-
-
-### Troubleshoot communication
-
-If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue.
-
-#### 1. Verify IP Address Configuration
-Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line):
-```bash
-hostname -I
-```
-
-#### 2. Check if Pi is reachable from laptop/pc
-Try pinging the Raspberry Pi from your laptop:
-```bach
-ping
-```
-
-If the ping fails:
-- Ensure the Pi is powered on and connected to the same network.
-- Check if SSH is enabled on the Pi.
-
-#### 3. Try SSH connection
-If you can't SSH into the Pi, it might not be properly connected. Use:
-```bash
-ssh @
-```
-If you get a connection error:
-- Ensure SSH is enabled on the Pi by running:
- ```bash
- sudo raspi-config
- ```
- Then navigate to: **Interfacing Options -> SSH** and enable it.
-
-### Calibration
-
-Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated.
-The calibration process is very important because it allows a neural network trained on one robot to work on another.
-
-### Calibrate follower arm (on mobile base)
-
-Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm:
-
-```bash
-python -m lerobot.calibrate \
- --robot.type=lekiwi \
- --robot.id=my_awesome_kiwi # <- Give the robot a unique name
-```
-
-We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video).
-
-### Wired version
-If you have the **wired** LeKiwi version, please run all commands on your laptop.
-
-### Calibrate leader arm
-Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop:
-
-
-
-```bash
-python -m lerobot.calibrate \
- --teleop.type=so100_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
-
-config = SO100LeaderConfig(
- port="/dev/tty.usbmodem58760431551",
- id="my_awesome_leader_arm",
-)
-
-leader = SO100Leader(config)
-leader.connect(calibrate=False)
-leader.calibrate()
-leader.disconnect()
-```
-
-
-
-## Teleoperate LeKiwi
-
-> [!TIP]
-> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
-
-To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command:
-```bash
-python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi
-```
-
-Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`.
-
-```bash
-python examples/lekiwi/teleoperate.py
-```
-
-You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
-
-| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
-| ---------- | ------------------ | ---------------------- |
-| Fast | 0.4 | 90 |
-| Medium | 0.25 | 60 |
-| Slow | 0.1 | 30 |
-
-
-| Key | Action |
-| --- | -------------- |
-| W | Move forward |
-| A | Move left |
-| S | Move backward |
-| D | Move right |
-| Z | Turn left |
-| X | Turn right |
-| R | Increase speed |
-| F | Decrease speed |
-
-> [!TIP]
-> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py).
-
-### Wired version
-If you have the **wired** LeKiwi version, please run all commands on your laptop.
-
-## Record a dataset
-
-Once you're familiar with teleoperation, you can record your first dataset.
-
-We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
-
-Add your token to the CLI by running this command:
-```bash
-huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
-```
-
-Then store your Hugging Face repository name in a variable:
-```bash
-HF_USER=$(huggingface-cli whoami | head -n 1)
-echo $HF_USER
-```
-
-Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`.
-```bash
-python examples/lekiwi/record.py
-```
-
-#### Dataset upload
-Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
-```bash
-echo https://huggingface.co/datasets/${HF_USER}/so101_test
-```
-Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
-
-You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
-
-#### Tips for gathering data
-
-Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
-
-In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
-
-Avoid adding too much variation too quickly, as it may hinder your results.
-
-If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset.
-
-#### Troubleshooting:
-- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
-
-
-## Replay an episode
-
-To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index.
-
-
-```bash
-python examples/lekiwi/replay.py
-```
-
-Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
-
-## Evaluate your policy
-
-To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model..
-
-```bash
-python examples/lekiwi/evaluate.py
-```
-
-> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/src/lerobot/robots/lekiwi/lekiwi.mdx b/src/lerobot/robots/lekiwi/lekiwi.mdx
new file mode 120000
index 000000000..f65158998
--- /dev/null
+++ b/src/lerobot/robots/lekiwi/lekiwi.mdx
@@ -0,0 +1 @@
+../../../../docs/source/lekiwi.mdx
\ No newline at end of file
diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py
index ff1465d8b..357109cb0 100644
--- a/src/lerobot/robots/lekiwi/lekiwi.py
+++ b/src/lerobot/robots/lekiwi/lekiwi.py
@@ -23,12 +23,12 @@ from typing import Any
import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -114,6 +114,9 @@ class LeKiwi(Robot):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
for cam in self.cameras.values():
@@ -127,6 +130,15 @@ class LeKiwi(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
logger.info(f"\nRunning calibration of {self}")
motors = self.arm_motors + self.base_motors
diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py
index 0ce259bb6..19744e244 100644
--- a/src/lerobot/robots/lekiwi/lekiwi_client.py
+++ b/src/lerobot/robots/lekiwi/lekiwi_client.py
@@ -18,13 +18,13 @@ import base64
import json
import logging
from functools import cached_property
-from typing import Any, Dict, Optional, Tuple
+from typing import Any
import cv2
import numpy as np
-import zmq
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.constants import ACTION, OBS_STATE
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_lekiwi import LeKiwiClientConfig
@@ -35,6 +35,9 @@ class LeKiwiClient(Robot):
name = "lekiwi_client"
def __init__(self, config: LeKiwiClientConfig):
+ import zmq
+
+ self._zmq = zmq
super().__init__(config)
self.config = config
self.id = config.id
@@ -117,6 +120,7 @@ class LeKiwiClient(Robot):
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
)
+ zmq = self._zmq
self.zmq_context = zmq.Context()
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
zmq_cmd_locator = f"tcp://{self.remote_ip}:{self.port_zmq_cmd}"
@@ -139,8 +143,9 @@ class LeKiwiClient(Robot):
def calibrate(self) -> None:
pass
- def _poll_and_get_latest_message(self) -> Optional[str]:
+ def _poll_and_get_latest_message(self) -> str | None:
"""Polls the ZMQ socket for a limited time and returns the latest message string."""
+ zmq = self._zmq
poller = zmq.Poller()
poller.register(self.zmq_observation_socket, zmq.POLLIN)
@@ -167,7 +172,7 @@ class LeKiwiClient(Robot):
return last_msg
- def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]:
+ def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None:
"""Parses the JSON observation string."""
try:
return json.loads(obs_string)
@@ -175,7 +180,7 @@ class LeKiwiClient(Robot):
logging.error(f"Error decoding JSON observation: {e}")
return None
- def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]:
+ def _decode_image_from_b64(self, image_b64: str) -> np.ndarray | None:
"""Decodes a base64 encoded image string to an OpenCV image."""
if not image_b64:
return None
@@ -191,18 +196,18 @@ class LeKiwiClient(Robot):
return None
def _remote_state_from_obs(
- self, observation: Dict[str, Any]
- ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
+ self, observation: dict[str, Any]
+ ) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
"""Extracts frames, and state from the parsed observation."""
flat_state = {key: observation.get(key, 0.0) for key in self._state_order}
state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)
- obs_dict: Dict[str, Any] = {**flat_state, "observation.state": state_vec}
+ obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}
# Decode images
- current_frames: Dict[str, np.ndarray] = {}
+ current_frames: dict[str, np.ndarray] = {}
for cam_name, image_b64 in observation.items():
if cam_name not in self._cameras_ft:
continue
@@ -212,7 +217,7 @@ class LeKiwiClient(Robot):
return current_frames, obs_dict
- def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]:
+ def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]:
"""
Polls the video socket for the latest observation data.
@@ -325,7 +330,7 @@ class LeKiwiClient(Robot):
actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32)
action_sent = {key: actions[i] for i, key in enumerate(self._state_order)}
- action_sent["action"] = actions
+ action_sent[ACTION] = actions
return action_sent
def disconnect(self):
diff --git a/src/lerobot/robots/reachy2/__init__.py b/src/lerobot/robots/reachy2/__init__.py
new file mode 100644
index 000000000..1a38fd03b
--- /dev/null
+++ b/src/lerobot/robots/reachy2/__init__.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_reachy2 import Reachy2RobotConfig
+from .robot_reachy2 import (
+ REACHY2_ANTENNAS_JOINTS,
+ REACHY2_L_ARM_JOINTS,
+ REACHY2_NECK_JOINTS,
+ REACHY2_R_ARM_JOINTS,
+ REACHY2_VEL,
+ Reachy2Robot,
+)
diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py
new file mode 100644
index 000000000..aa25351c6
--- /dev/null
+++ b/src/lerobot/robots/reachy2/configuration_reachy2.py
@@ -0,0 +1,107 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.cameras import CameraConfig
+from lerobot.cameras.configs import ColorMode
+from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
+
+from ..config import RobotConfig
+
+
+@RobotConfig.register_subclass("reachy2")
+@dataclass
+class Reachy2RobotConfig(RobotConfig):
+ # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
+ # Set this to a positive scalar to have the same value for all motors.
+ max_relative_target: float | None = None
+
+ # IP address of the Reachy 2 robot
+ ip_address: str | None = "localhost"
+
+ # If True, turn_off_smoothly() will be sent to the robot before disconnecting.
+ disable_torque_on_disconnect: bool = False
+
+ # Tag for external commands control
+ # Set to True if you use an external commands system to control the robot,
+ # such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
+ # If True, robot.send_action() will not send commands to the robot.
+ use_external_commands: bool = False
+
+ # Robot parts
+ # Set to False to not add the corresponding joints part to the robot list of joints.
+ # By default, all parts are set to True.
+ with_mobile_base: bool = True
+ with_l_arm: bool = True
+ with_r_arm: bool = True
+ with_neck: bool = True
+ with_antennas: bool = True
+
+ # Robot cameras
+ # Set to True if you want to use the corresponding cameras in the observations.
+ # By default, only the teleop cameras are used.
+ with_left_teleop_camera: bool = True
+ with_right_teleop_camera: bool = True
+ with_torso_camera: bool = False
+
+ cameras: dict[str, CameraConfig] = field(default_factory=dict)
+
+ def __post_init__(self) -> None:
+ # Add cameras with same ip_address as the robot
+ if self.with_left_teleop_camera:
+ self.cameras["teleop_left"] = Reachy2CameraConfig(
+ name="teleop",
+ image_type="left",
+ ip_address=self.ip_address,
+ fps=15,
+ width=640,
+ height=480,
+ color_mode=ColorMode.RGB,
+ )
+ if self.with_right_teleop_camera:
+ self.cameras["teleop_right"] = Reachy2CameraConfig(
+ name="teleop",
+ image_type="right",
+ ip_address=self.ip_address,
+ fps=15,
+ width=640,
+ height=480,
+ color_mode=ColorMode.RGB,
+ )
+ if self.with_torso_camera:
+ self.cameras["torso_rgb"] = Reachy2CameraConfig(
+ name="depth",
+ image_type="rgb",
+ ip_address=self.ip_address,
+ fps=15,
+ width=640,
+ height=480,
+ color_mode=ColorMode.RGB,
+ )
+
+ super().__post_init__()
+
+ if not (
+ self.with_mobile_base
+ or self.with_l_arm
+ or self.with_r_arm
+ or self.with_neck
+ or self.with_antennas
+ ):
+ raise ValueError(
+ "No Reachy2Robot part used.\n"
+ "At least one part of the robot must be set to True "
+ "(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
+ )
diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py
new file mode 100644
index 000000000..ecc488a79
--- /dev/null
+++ b/src/lerobot/robots/reachy2/robot_reachy2.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from typing import Any
+
+import numpy as np
+from reachy2_sdk import ReachySDK
+
+from lerobot.cameras.utils import make_cameras_from_configs
+
+from ..robot import Robot
+from ..utils import ensure_safe_goal_position
+from .configuration_reachy2 import Reachy2RobotConfig
+
+# {lerobot_keys: reachy2_sdk_keys}
+REACHY2_NECK_JOINTS = {
+ "neck_yaw.pos": "head.neck.yaw",
+ "neck_pitch.pos": "head.neck.pitch",
+ "neck_roll.pos": "head.neck.roll",
+}
+
+REACHY2_ANTENNAS_JOINTS = {
+ "l_antenna.pos": "head.l_antenna",
+ "r_antenna.pos": "head.r_antenna",
+}
+
+REACHY2_R_ARM_JOINTS = {
+ "r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
+ "r_shoulder_roll.pos": "r_arm.shoulder.roll",
+ "r_elbow_yaw.pos": "r_arm.elbow.yaw",
+ "r_elbow_pitch.pos": "r_arm.elbow.pitch",
+ "r_wrist_roll.pos": "r_arm.wrist.roll",
+ "r_wrist_pitch.pos": "r_arm.wrist.pitch",
+ "r_wrist_yaw.pos": "r_arm.wrist.yaw",
+ "r_gripper.pos": "r_arm.gripper",
+}
+
+REACHY2_L_ARM_JOINTS = {
+ "l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
+ "l_shoulder_roll.pos": "l_arm.shoulder.roll",
+ "l_elbow_yaw.pos": "l_arm.elbow.yaw",
+ "l_elbow_pitch.pos": "l_arm.elbow.pitch",
+ "l_wrist_roll.pos": "l_arm.wrist.roll",
+ "l_wrist_pitch.pos": "l_arm.wrist.pitch",
+ "l_wrist_yaw.pos": "l_arm.wrist.yaw",
+ "l_gripper.pos": "l_arm.gripper",
+}
+
+REACHY2_VEL = {
+ "mobile_base.vx": "vx",
+ "mobile_base.vy": "vy",
+ "mobile_base.vtheta": "vtheta",
+}
+
+
+class Reachy2Robot(Robot):
+ """
+ [Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
+ """
+
+ config_class = Reachy2RobotConfig
+ name = "reachy2"
+
+ def __init__(self, config: Reachy2RobotConfig):
+ super().__init__(config)
+
+ self.config = config
+ self.robot_type = self.config.type
+ self.use_external_commands = self.config.use_external_commands
+
+ self.reachy: None | ReachySDK = None
+ self.cameras = make_cameras_from_configs(config.cameras)
+
+ self.logs: dict[str, float] = {}
+
+ self.joints_dict: dict[str, str] = self._generate_joints_dict()
+
+ @property
+ def observation_features(self) -> dict[str, Any]:
+ return {**self.motors_features, **self.camera_features}
+
+ @property
+ def action_features(self) -> dict[str, type]:
+ return self.motors_features
+
+ @property
+ def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
+ return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
+
+ @property
+ def motors_features(self) -> dict[str, type]:
+ if self.config.with_mobile_base:
+ return {
+ **dict.fromkeys(
+ self.joints_dict.keys(),
+ float,
+ ),
+ **dict.fromkeys(
+ REACHY2_VEL.keys(),
+ float,
+ ),
+ }
+ else:
+ return dict.fromkeys(self.joints_dict.keys(), float)
+
+ @property
+ def is_connected(self) -> bool:
+ return self.reachy.is_connected() if self.reachy is not None else False
+
+ def connect(self, calibrate: bool = False) -> None:
+ self.reachy = ReachySDK(self.config.ip_address)
+ if not self.is_connected:
+ raise ConnectionError()
+
+ for cam in self.cameras.values():
+ cam.connect()
+
+ self.configure()
+
+ def configure(self) -> None:
+ if self.reachy is not None:
+ self.reachy.turn_on()
+ self.reachy.reset_default_limits()
+
+ @property
+ def is_calibrated(self) -> bool:
+ return True
+
+ def calibrate(self) -> None:
+ pass
+
+ def _generate_joints_dict(self) -> dict[str, str]:
+ joints = {}
+ if self.config.with_neck:
+ joints.update(REACHY2_NECK_JOINTS)
+ if self.config.with_l_arm:
+ joints.update(REACHY2_L_ARM_JOINTS)
+ if self.config.with_r_arm:
+ joints.update(REACHY2_R_ARM_JOINTS)
+ if self.config.with_antennas:
+ joints.update(REACHY2_ANTENNAS_JOINTS)
+ return joints
+
+ def _get_state(self) -> dict[str, float]:
+ if self.reachy is not None:
+ pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
+ if not self.config.with_mobile_base:
+ return pos_dict
+ vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
+ return {**pos_dict, **vel_dict}
+ else:
+ return {}
+
+ def get_observation(self) -> dict[str, np.ndarray]:
+ obs_dict: dict[str, Any] = {}
+
+ # Read Reachy 2 state
+ before_read_t = time.perf_counter()
+ obs_dict.update(self._get_state())
+ self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
+
+ # Capture images from cameras
+ for cam_key, cam in self.cameras.items():
+ obs_dict[cam_key] = cam.async_read()
+
+ return obs_dict
+
+ def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
+ if self.reachy is not None:
+ if not self.is_connected:
+ raise ConnectionError()
+
+ before_write_t = time.perf_counter()
+
+ vel = {}
+ goal_pos = {}
+ for key, val in action.items():
+ if key not in self.joints_dict:
+ if key not in REACHY2_VEL:
+ raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
+ else:
+ vel[REACHY2_VEL[key]] = float(val)
+ else:
+ if not self.use_external_commands and self.config.max_relative_target is not None:
+ goal_pos[key] = float(val)
+ goal_present_pos = {
+ key: (
+ goal_pos[key],
+ self.reachy.joints[self.joints_dict[key]].present_position,
+ )
+ }
+ safe_goal_pos = ensure_safe_goal_position(
+ goal_present_pos, float(self.config.max_relative_target)
+ )
+ val = safe_goal_pos[key]
+ self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
+
+ if self.config.with_mobile_base:
+ self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
+
+ # We don't send the goal positions if we control Reachy 2 externally
+ if not self.use_external_commands:
+ self.reachy.send_goal_positions()
+ if self.config.with_mobile_base:
+ self.reachy.mobile_base.send_speed_command()
+
+ self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
+ return action
+
+ def disconnect(self) -> None:
+ if self.reachy is not None:
+ for cam in self.cameras.values():
+ cam.disconnect()
+ if self.config.disable_torque_on_disconnect:
+ self.reachy.turn_off_smoothly()
+ self.reachy.disconnect()
diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py
index 6820645cc..5e88b915b 100644
--- a/src/lerobot/robots/robot.py
+++ b/src/lerobot/robots/robot.py
@@ -13,13 +13,14 @@
# limitations under the License.
import abc
+import builtins
from pathlib import Path
-from typing import Any, Type
+from typing import Any
import draccus
-from lerobot.constants import HF_LEROBOT_CALIBRATION, ROBOTS
from lerobot.motors import MotorCalibration
+from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
from .config import RobotConfig
@@ -39,7 +40,7 @@ class Robot(abc.ABC):
"""
# Set these in ALL subclasses
- config_class: Type[RobotConfig]
+ config_class: builtins.type[RobotConfig]
name: str
def __init__(self, config: RobotConfig):
diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py
index 63c3e1c17..5dc43ac3b 100644
--- a/src/lerobot/robots/so100_follower/__init__.py
+++ b/src/lerobot/robots/so100_follower/__init__.py
@@ -1,3 +1,18 @@
-from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config_so100_follower import SO100FollowerConfig
from .so100_follower import SO100Follower
-from .so100_follower_end_effector import SO100FollowerEndEffector
diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py
index 7cd23d340..272b8c43f 100644
--- a/src/lerobot/robots/so100_follower/config_so100_follower.py
+++ b/src/lerobot/robots/so100_follower/config_so100_follower.py
@@ -1,4 +1,6 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,44 +30,12 @@ class SO100FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
- max_relative_target: int | None = None
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
+ max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
-
-
-@RobotConfig.register_subclass("so100_follower_end_effector")
-@dataclass
-class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
- """Configuration for the SO100FollowerEndEffector robot."""
-
- # Path to URDF file for kinematics
- # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
- # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
- urdf_path: str | None = None
-
- # End-effector frame name in the URDF
- target_frame_name: str = "gripper_frame_link"
-
- # Default bounds for the end-effector position (in meters)
- end_effector_bounds: dict[str, list[float]] = field(
- default_factory=lambda: {
- "min": [-1.0, -1.0, -1.0], # min x, y, z
- "max": [1.0, 1.0, 1.0], # max x, y, z
- }
- )
-
- max_gripper_pos: float = 50
-
- end_effector_step_sizes: dict[str, float] = field(
- default_factory=lambda: {
- "x": 0.02,
- "y": 0.02,
- "z": 0.02,
- }
- )
diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py
new file mode 100644
index 000000000..87e832db6
--- /dev/null
+++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py
@@ -0,0 +1,610 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Any
+
+import numpy as np
+
+from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.model.kinematics import RobotKinematics
+from lerobot.processor import (
+ EnvTransition,
+ ObservationProcessorStep,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ RobotAction,
+ RobotActionProcessorStep,
+ TransitionKey,
+)
+from lerobot.utils.rotation import Rotation
+
+
+@ProcessorStepRegistry.register("ee_reference_and_delta")
+@dataclass
+class EEReferenceAndDelta(RobotActionProcessorStep):
+ """
+ Computes a target end-effector pose from a relative delta command.
+
+ This step takes a desired change in position and orientation (`target_*`) and applies it to a
+ reference end-effector pose to calculate an absolute target pose. The reference pose is derived
+ from the current robot joint positions using forward kinematics.
+
+ The processor can operate in two modes:
+ 1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action
+ is first enabled. Subsequent commands are relative to this fixed reference.
+ 2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at
+ every step.
+
+ Attributes:
+ kinematics: The robot's kinematic model for forward kinematics.
+ end_effector_step_sizes: A dictionary scaling the input delta commands.
+ motor_names: A list of motor names required for forward kinematics.
+ use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the
+ current pose as the reference.
+ reference_ee_pose: Internal state storing the latched reference pose.
+ _prev_enabled: Internal state to detect the rising edge of the enable signal.
+ _command_when_disabled: Internal state to hold the last command while disabled.
+ """
+
+ kinematics: RobotKinematics
+ end_effector_step_sizes: dict
+ motor_names: list[str]
+ use_latched_reference: bool = (
+ True # If True, latch reference on enable; if False, always use current pose
+ )
+ use_ik_solution: bool = False
+
+ reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
+ _prev_enabled: bool = field(default=False, init=False, repr=False)
+ _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
+
+ def action(self, action: RobotAction) -> RobotAction:
+ observation = self.transition.get(TransitionKey.OBSERVATION).copy()
+
+ if observation is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA):
+ q_raw = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
+ else:
+ q_raw = np.array(
+ [
+ float(v)
+ for k, v in observation.items()
+ if isinstance(k, str)
+ and k.endswith(".pos")
+ and k.removesuffix(".pos") in self.motor_names
+ ],
+ dtype=float,
+ )
+
+ if q_raw is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ # Current pose from FK on measured joints
+ t_curr = self.kinematics.forward_kinematics(q_raw)
+
+ enabled = bool(action.pop("enabled"))
+ tx = float(action.pop("target_x"))
+ ty = float(action.pop("target_y"))
+ tz = float(action.pop("target_z"))
+ wx = float(action.pop("target_wx"))
+ wy = float(action.pop("target_wy"))
+ wz = float(action.pop("target_wz"))
+ gripper_vel = float(action.pop("gripper_vel"))
+
+ desired = None
+
+ if enabled:
+ ref = t_curr
+ if self.use_latched_reference:
+ # Latched reference mode: latch reference at the rising edge
+ if not self._prev_enabled or self.reference_ee_pose is None:
+ self.reference_ee_pose = t_curr.copy()
+ ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
+
+ delta_p = np.array(
+ [
+ tx * self.end_effector_step_sizes["x"],
+ ty * self.end_effector_step_sizes["y"],
+ tz * self.end_effector_step_sizes["z"],
+ ],
+ dtype=float,
+ )
+ r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
+ desired = np.eye(4, dtype=float)
+ desired[:3, :3] = ref[:3, :3] @ r_abs
+ desired[:3, 3] = ref[:3, 3] + delta_p
+
+ self._command_when_disabled = desired.copy()
+ else:
+ # While disabled, keep sending the same command to avoid drift.
+ if self._command_when_disabled is None:
+ # If we've never had an enabled command yet, freeze current FK pose once.
+ self._command_when_disabled = t_curr.copy()
+ desired = self._command_when_disabled.copy()
+
+ # Write action fields
+ pos = desired[:3, 3]
+ tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
+ action["ee.x"] = float(pos[0])
+ action["ee.y"] = float(pos[1])
+ action["ee.z"] = float(pos[2])
+ action["ee.wx"] = float(tw[0])
+ action["ee.wy"] = float(tw[1])
+ action["ee.wz"] = float(tw[2])
+ action["ee.gripper_vel"] = gripper_vel
+
+ self._prev_enabled = enabled
+ return action
+
+ def reset(self):
+ """Resets the internal state of the processor."""
+ self._prev_enabled = False
+ self.reference_ee_pose = None
+ self._command_when_disabled = None
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for feat in [
+ "enabled",
+ "target_x",
+ "target_y",
+ "target_z",
+ "target_wx",
+ "target_wy",
+ "target_wz",
+ "gripper_vel",
+ ]:
+ features[PipelineFeatureType.ACTION].pop(f"{feat}", None)
+
+ for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]:
+ features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
+
+
+@ProcessorStepRegistry.register("ee_bounds_and_safety")
+@dataclass
+class EEBoundsAndSafety(RobotActionProcessorStep):
+ """
+ Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
+
+ This step ensures that the target end-effector pose remains within a safe operational workspace.
+ It also moderates the command to prevent large, sudden movements between consecutive steps.
+
+ Attributes:
+ end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
+ max_ee_step_m: The maximum allowed change in position (in meters) between steps.
+ _last_pos: Internal state storing the last commanded position.
+ """
+
+ end_effector_bounds: dict
+ max_ee_step_m: float = 0.05
+ _last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
+
+ def action(self, action: RobotAction) -> RobotAction:
+ x = action["ee.x"]
+ y = action["ee.y"]
+ z = action["ee.z"]
+ wx = action["ee.wx"]
+ wy = action["ee.wy"]
+ wz = action["ee.wz"]
+ # TODO(Steven): ee.gripper_vel does not need to be bounded
+
+ if None in (x, y, z, wx, wy, wz):
+ raise ValueError(
+ "Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
+ )
+
+ pos = np.array([x, y, z], dtype=float)
+ twist = np.array([wx, wy, wz], dtype=float)
+
+ # Clip position
+ pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
+
+ # Check for jumps in position
+ if self._last_pos is not None:
+ dpos = pos - self._last_pos
+ n = float(np.linalg.norm(dpos))
+ if n > self.max_ee_step_m and n > 0:
+ pos = self._last_pos + dpos * (self.max_ee_step_m / n)
+ raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
+
+ self._last_pos = pos
+
+ action["ee.x"] = float(pos[0])
+ action["ee.y"] = float(pos[1])
+ action["ee.z"] = float(pos[2])
+ action["ee.wx"] = float(twist[0])
+ action["ee.wy"] = float(twist[1])
+ action["ee.wz"] = float(twist[2])
+ return action
+
+ def reset(self):
+ """Resets the last known position and orientation."""
+ self._last_pos = None
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
+
+
+@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
+@dataclass
+class InverseKinematicsEEToJoints(RobotActionProcessorStep):
+ """
+ Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
+
+ This step translates a Cartesian command (position and orientation of the end-effector) into
+ the corresponding joint-space commands for each motor.
+
+ Attributes:
+ kinematics: The robot's kinematic model for inverse kinematics.
+ motor_names: A list of motor names for which to compute joint positions.
+ q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
+ initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
+ If False, use the solution from the previous step.
+ """
+
+ kinematics: RobotKinematics
+ motor_names: list[str]
+ q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
+ initial_guess_current_joints: bool = True
+
+ def action(self, action: RobotAction) -> RobotAction:
+ x = action.pop("ee.x")
+ y = action.pop("ee.y")
+ z = action.pop("ee.z")
+ wx = action.pop("ee.wx")
+ wy = action.pop("ee.wy")
+ wz = action.pop("ee.wz")
+ gripper_pos = action.pop("ee.gripper_pos")
+
+ if None in (x, y, z, wx, wy, wz, gripper_pos):
+ raise ValueError(
+ "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
+ )
+
+ observation = self.transition.get(TransitionKey.OBSERVATION).copy()
+ if observation is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ q_raw = np.array(
+ [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
+ dtype=float,
+ )
+ if q_raw is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ if self.initial_guess_current_joints: # Use current joints as initial guess
+ self.q_curr = q_raw
+ else: # Use previous ik solution as initial guess
+ if self.q_curr is None:
+ self.q_curr = q_raw
+
+ # Build desired 4x4 transform from pos + rotvec (twist)
+ t_des = np.eye(4, dtype=float)
+ t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
+ t_des[:3, 3] = [x, y, z]
+
+ # Compute inverse kinematics
+ q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
+ self.q_curr = q_target
+
+ # TODO: This is sentitive to order of motor_names = q_target mapping
+ for i, name in enumerate(self.motor_names):
+ if name != "gripper":
+ action[f"{name}.pos"] = float(q_target[i])
+ else:
+ action["gripper.pos"] = float(gripper_pos)
+
+ return action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
+ features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
+
+ for name in self.motor_names:
+ features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
+
+ def reset(self):
+ """Resets the initial guess for the IK solver."""
+ self.q_curr = None
+
+
+@ProcessorStepRegistry.register("gripper_velocity_to_joint")
+@dataclass
+class GripperVelocityToJoint(RobotActionProcessorStep):
+ """
+ Converts a gripper velocity command into a target gripper joint position.
+
+ This step integrates a normalized velocity command over time to produce a position command,
+ taking the current gripper position as a starting point. It also supports a discrete mode
+ where integer actions map to open, close, or no-op.
+
+ Attributes:
+ motor_names: A list of motor names, which must include 'gripper'.
+ speed_factor: A scaling factor to convert the normalized velocity command to a position change.
+ clip_min: The minimum allowed gripper joint position.
+ clip_max: The maximum allowed gripper joint position.
+ discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
+ """
+
+ speed_factor: float = 20.0
+ clip_min: float = 0.0
+ clip_max: float = 100.0
+ discrete_gripper: bool = False
+
+ def action(self, action: RobotAction) -> RobotAction:
+ observation = self.transition.get(TransitionKey.OBSERVATION).copy()
+
+ gripper_vel = action.pop("ee.gripper_vel")
+
+ if observation is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ q_raw = np.array(
+ [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
+ dtype=float,
+ )
+ if q_raw is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ if self.discrete_gripper:
+ # Discrete gripper actions are in [0, 1, 2]
+ # 0: open, 1: close, 2: stay
+ # We need to shift them to [-1, 0, 1] and then scale them to clip_max
+ gripper_vel = (gripper_vel - 1) * self.clip_max
+
+ # Compute desired gripper position
+ delta = gripper_vel * float(self.speed_factor)
+ # TODO: This assumes gripper is the last specified joint in the robot
+ gripper_pos = float(np.clip(q_raw[-1] + delta, self.clip_min, self.clip_max))
+ action["ee.gripper_pos"] = gripper_pos
+
+ return action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None)
+ features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
+
+
+def compute_forward_kinematics_joints_to_ee(
+ joints: dict[str, Any], kinematics: RobotKinematics, motor_names: list[str]
+) -> dict[str, Any]:
+ motor_joint_values = [joints[f"{n}.pos"] for n in motor_names]
+
+ q = np.array(motor_joint_values, dtype=float)
+ t = kinematics.forward_kinematics(q)
+ pos = t[:3, 3]
+ tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
+ gripper_pos = joints["gripper.pos"]
+ for n in motor_names:
+ joints.pop(f"{n}.pos")
+ joints["ee.x"] = float(pos[0])
+ joints["ee.y"] = float(pos[1])
+ joints["ee.z"] = float(pos[2])
+ joints["ee.wx"] = float(tw[0])
+ joints["ee.wy"] = float(tw[1])
+ joints["ee.wz"] = float(tw[2])
+ joints["ee.gripper_pos"] = float(gripper_pos)
+ return joints
+
+
+@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_observation")
+@dataclass
+class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep):
+ """
+ Computes the end-effector pose from joint positions using forward kinematics (FK).
+
+ This step is typically used to add the robot's Cartesian pose to the observation space,
+ which can be useful for visualization or as an input to a policy.
+
+ Attributes:
+ kinematics: The robot's kinematic model.
+ """
+
+ kinematics: RobotKinematics
+ motor_names: list[str]
+
+ def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+ return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names)
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ # We only use the ee pose in the dataset, so we don't need the joint positions
+ for n in self.motor_names:
+ features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
+ # We specify the dataset features of this step that we want to be stored in the dataset
+ for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
+ features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature(
+ type=FeatureType.STATE, shape=(1,)
+ )
+ return features
+
+
+@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_action")
+@dataclass
+class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep):
+ """
+ Computes the end-effector pose from joint positions using forward kinematics (FK).
+
+ This step is typically used to add the robot's Cartesian pose to the observation space,
+ which can be useful for visualization or as an input to a policy.
+
+ Attributes:
+ kinematics: The robot's kinematic model.
+ """
+
+ kinematics: RobotKinematics
+ motor_names: list[str]
+
+ def action(self, action: RobotAction) -> RobotAction:
+ return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names)
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ # We only use the ee pose in the dataset, so we don't need the joint positions
+ for n in self.motor_names:
+ features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
+ # We specify the dataset features of this step that we want to be stored in the dataset
+ for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
+ features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature(
+ type=FeatureType.STATE, shape=(1,)
+ )
+ return features
+
+
+@ProcessorStepRegistry.register(name="forward_kinematics_joints_to_ee")
+@dataclass
+class ForwardKinematicsJointsToEE(ProcessorStep):
+ kinematics: RobotKinematics
+ motor_names: list[str]
+
+ def __post_init__(self):
+ self.joints_to_ee_action_processor = ForwardKinematicsJointsToEEAction(
+ kinematics=self.kinematics, motor_names=self.motor_names
+ )
+ self.joints_to_ee_observation_processor = ForwardKinematicsJointsToEEObservation(
+ kinematics=self.kinematics, motor_names=self.motor_names
+ )
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ if transition.get(TransitionKey.ACTION) is not None:
+ transition = self.joints_to_ee_action_processor(transition)
+ if transition.get(TransitionKey.OBSERVATION) is not None:
+ transition = self.joints_to_ee_observation_processor(transition)
+ return transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ if features[PipelineFeatureType.ACTION] is not None:
+ features = self.joints_to_ee_action_processor.transform_features(features)
+ if features[PipelineFeatureType.OBSERVATION] is not None:
+ features = self.joints_to_ee_observation_processor.transform_features(features)
+ return features
+
+
+@ProcessorStepRegistry.register("inverse_kinematics_rl_step")
+@dataclass
+class InverseKinematicsRLStep(ProcessorStep):
+ """
+ Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
+
+ This is modified from the InverseKinematicsEEToJoints step to be used in the RL pipeline.
+ """
+
+ kinematics: RobotKinematics
+ motor_names: list[str]
+ q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
+ initial_guess_current_joints: bool = True
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ new_transition = dict(transition)
+ action = new_transition.get(TransitionKey.ACTION)
+ if action is None:
+ raise ValueError("Action is required for InverseKinematicsEEToJoints")
+ action = dict(action)
+
+ x = action.pop("ee.x")
+ y = action.pop("ee.y")
+ z = action.pop("ee.z")
+ wx = action.pop("ee.wx")
+ wy = action.pop("ee.wy")
+ wz = action.pop("ee.wz")
+ gripper_pos = action.pop("ee.gripper_pos")
+
+ if None in (x, y, z, wx, wy, wz, gripper_pos):
+ raise ValueError(
+ "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
+ )
+
+ observation = new_transition.get(TransitionKey.OBSERVATION).copy()
+ if observation is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ q_raw = np.array(
+ [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
+ dtype=float,
+ )
+ if q_raw is None:
+ raise ValueError("Joints observation is require for computing robot kinematics")
+
+ if self.initial_guess_current_joints: # Use current joints as initial guess
+ self.q_curr = q_raw
+ else: # Use previous ik solution as initial guess
+ if self.q_curr is None:
+ self.q_curr = q_raw
+
+ # Build desired 4x4 transform from pos + rotvec (twist)
+ t_des = np.eye(4, dtype=float)
+ t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
+ t_des[:3, 3] = [x, y, z]
+
+ # Compute inverse kinematics
+ q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
+ self.q_curr = q_target
+
+ # TODO: This is sentitive to order of motor_names = q_target mapping
+ for i, name in enumerate(self.motor_names):
+ if name != "gripper":
+ action[f"{name}.pos"] = float(q_target[i])
+ else:
+ action["gripper.pos"] = float(gripper_pos)
+
+ new_transition[TransitionKey.ACTION] = action
+ complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+ complementary_data["IK_solution"] = q_target
+ new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
+ return new_transition
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
+ features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
+
+ for name in self.motor_names:
+ features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
+
+ def reset(self):
+ """Resets the initial guess for the IK solver."""
+ self.q_curr = None
diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx
deleted file mode 100644
index f5eea6aef..000000000
--- a/src/lerobot/robots/so100_follower/so100.mdx
+++ /dev/null
@@ -1,489 +0,0 @@
-# SO-100
-
-In the steps below, we explain how to assemble the SO-100 robot.
-
-## Source the parts
-
-Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100.md). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. And advise if it's your first time printing or if you don't own a 3D printer.
-
-## Install LeRobot 🤗
-
-To install LeRobot, follow our [Installation Guide](./installation)
-
-In addition to these instructions, you need to install the Feetech SDK:
-```bash
-pip install -e ".[feetech]"
-```
-
-## Configure the motors
-
-**Note:**
-Unlike the SO-101, the motor connectors are not easily accessible once the arm is assembled, so the configuration step must be done beforehand.
-
-### 1. Find the USB ports associated with each arm
-
-To find the port for each bus servo adapter, run this script:
-```bash
-python -m lerobot.find_port
-```
-
-
-
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the USB cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/tty.usbmodem575E0032081
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
-
-
-
-
-On Linux, you might need to give access to the USB ports by running:
-```bash
-sudo chmod 666 /dev/ttyACM0
-sudo chmod 666 /dev/ttyACM1
-```
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/ttyACM0', '/dev/ttyACM1']
-Remove the usb cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/ttyACM1
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
-
-
-
-
-### 2. Set the motors ids and baudrates
-
-Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
-
-To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
-
-If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match.
-
-#### Follower
-
-Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
-
-For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm.
-
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --robot.type=so100_follower \
- --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig
-
-config = SO100FollowerConfig(
- port="/dev/tty.usbmodem585A0076841",
- id="my_awesome_follower_arm",
-)
-follower = SO100Follower(config)
-follower.setup_motors()
-```
-
-
-
-You should see the following instruction
-```
-Connect the controller board to the 'gripper' motor only and press enter.
-```
-
-As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
-
-
-Troubleshooting
-
- If you get an error at that point, check your cables and make sure they are plugged in properly:
-
-
Power supply
-
USB cable between your computer and the controller board
-
The 3-pin cable from the controller board to the motor
-
-
-If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
-
-
-You should then see the following message:
-```
-'gripper' motor id set to 6
-```
-
-Followed by the next instruction:
-```
-Connect the controller board to the 'wrist_roll' motor only and press enter.
-```
-
-You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
-
-Repeat the operation for each motor as instructed.
-
-> [!TIP]
-> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
-
-When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
-
-#### Leader
-Do the same steps for the leader arm.
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --teleop.type=so100_leader \
- --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
-
-config = SO100LeaderConfig(
- port="/dev/tty.usbmodem585A0076841",
- id="my_awesome_leader_arm",
-)
-leader = SO100Leader(config)
-leader.setup_motors()
-```
-
-
-
-## Step-by-Step Assembly Instructions
-
-## Remove the gears of the 6 leader motors
-
-
-Video removing gears
-
-
-
-
-
-
-
-Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
-
-### Clean Parts
-Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
-
-### Additional Guidance
-
-
-Video assembling arms
-
-
-
-
-
-
-
-**Note:**
-This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour.
-
----
-
-### First Motor
-
-**Step 2: Insert Wires**
-- Insert two wires into the first motor.
-
-
-
-**Step 3: Install in Base**
-- Place the first motor into the base.
-
-
-
-**Step 4: Secure Motor**
-- Fasten the motor with 4 screws. Two from the bottom and two from top.
-
-**Step 5: Attach Motor Holder**
-- Slide over the first motor holder and fasten it using two screws (one on each side).
-
-
-
-**Step 6: Attach Motor Horns**
-- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears.
-
-
-
-
- Video adding motor horn
-
-
-
-**Step 7: Attach Shoulder Part**
-- Route one wire to the back of the robot and the other to the left or towards you (see photo).
-- Attach the shoulder part.
-
-
-
-**Step 8: Secure Shoulder**
-- Tighten the shoulder part with 4 screws on top and 4 on the bottom
-*(access bottom holes by turning the shoulder).*
-
----
-
-### Second Motor Assembly
-
-**Step 9: Install Motor 2**
-- Slide the second motor in from the top and link the wire from motor 1 to motor 2.
-
-
-
-**Step 10: Attach Shoulder Holder**
-- Add the shoulder motor holder.
-- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo).
-- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
-
-
-
-
-
-
-
-**Step 11: Secure Motor 2**
-- Fasten the second motor with 4 screws.
-
-**Step 12: Attach Motor Horn**
-- Attach both motor horns to motor 2, again use the horn screw.
-
-**Step 13: Attach Base**
-- Install the base attachment using 2 screws.
-
-
-
-**Step 14: Attach Upper Arm**
-- Attach the upper arm with 4 screws on each side.
-
-
-
----
-
-### Third Motor Assembly
-
-**Step 15: Install Motor 3**
-- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws.
-
-**Step 16: Attach Motor Horn**
-- Attach both motor horns to motor 3 and secure one again with a horn screw.
-
-
-
-**Step 17: Attach Forearm**
-- Connect the forearm to motor 3 using 4 screws on each side.
-
-
-
----
-
-### Fourth Motor Assembly
-
-**Step 18: Install Motor 4**
-- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
-
-
-
-
-
-
-**Step 19: Attach Motor Holder 4**
-- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo).
-
-
-
-**Step 20: Secure Motor 4 & Attach Horn**
-- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw.
-
-
-
----
-
-### Wrist Assembly
-
-**Step 21: Install Motor 5**
-- Insert motor 5 into the wrist holder and secure it with 2 front screws.
-
-
-
-**Step 22: Attach Wrist**
-- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper.
-- Secure the wrist to motor 4 using 4 screws on both sides.
-
-
-
-**Step 23: Attach Wrist Horn**
-- Install only one motor horn on the wrist motor and secure it with a horn screw.
-
-
-
----
-
-### Follower Configuration
-
-**Step 24: Attach Gripper**
-- Attach the gripper to motor 5.
-
-
-
-**Step 25: Install Gripper Motor**
-- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side.
-
-
-
-**Step 26: Attach Gripper Horn & Claw**
-- Attach the motor horns and again use a horn screw.
-- Install the gripper claw and secure it with 4 screws on both sides.
-
-
-
-**Step 27: Mount Controller**
-- Attach the motor controller to the back of the robot.
-
-
-
-
-
-
-*Assembly complete – proceed to Leader arm assembly.*
-
----
-
-### Leader Configuration
-
-For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors.
-
-**Step 24: Attach Leader Holder**
-- Mount the leader holder onto the wrist and secure it with a screw.
-
-
-
-**Step 25: Attach Handle**
-- Attach the handle to motor 5 using 4 screws.
-
-
-
-**Step 26: Install Gripper Motor**
-- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire.
-
-
-
-**Step 27: Attach Trigger**
-- Attach the follower trigger with 4 screws.
-
-
-
-**Step 28: Mount Controller**
-- Attach the motor controller to the back of the robot.
-
-
-
-
-
-
-## Calibrate
-
-Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
-The calibration process is very important because it allows a neural network trained on one robot to work on another.
-
-#### Follower
-
-Run the following command or API example to calibrate the follower arm:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --robot.type=so100_follower \
- --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower
-
-config = SO100FollowerConfig(
- port="/dev/tty.usbmodem585A0076891",
- id="my_awesome_follower_arm",
-)
-
-follower = SO100Follower(config)
-follower.connect(calibrate=False)
-follower.calibrate()
-follower.disconnect()
-```
-
-
-
-We unified the calibration method for most robots. Thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video)
-
-#### Leader
-
-Do the same steps to calibrate the leader arm, run the following command or API example:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --teleop.type=so100_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader
-
-config = SO100LeaderConfig(
- port="/dev/tty.usbmodem58760431551",
- id="my_awesome_leader_arm",
-)
-
-leader = SO100Leader(config)
-leader.connect(calibrate=False)
-leader.calibrate()
-leader.disconnect()
-```
-
-
-
-Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
-
-> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx
new file mode 120000
index 000000000..ad1154e75
--- /dev/null
+++ b/src/lerobot/robots/so100_follower/so100.mdx
@@ -0,0 +1 @@
+../../../../docs/source/so100.mdx
\ No newline at end of file
diff --git a/src/lerobot/robots/so100_follower/so100_follower.py b/src/lerobot/robots/so100_follower/so100_follower.py
index e5da6bc1a..d660ebed4 100644
--- a/src/lerobot/robots/so100_follower/so100_follower.py
+++ b/src/lerobot/robots/so100_follower/so100_follower.py
@@ -20,12 +20,12 @@ from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -92,6 +92,9 @@ class SO100Follower(Robot):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
for cam in self.cameras.values():
@@ -105,6 +108,16 @@ class SO100Follower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
@@ -148,6 +161,11 @@ class SO100Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
+ if motor == "gripper":
+ self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout
+ self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
+ self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
+
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
diff --git a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py b/src/lerobot/robots/so100_follower/so100_follower_end_effector.py
deleted file mode 100644
index 5fe2993cb..000000000
--- a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# !/usr/bin/env python
-
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import time
-from typing import Any
-
-import numpy as np
-
-from lerobot.cameras import make_cameras_from_configs
-from lerobot.errors import DeviceNotConnectedError
-from lerobot.model.kinematics import RobotKinematics
-from lerobot.motors import Motor, MotorNormMode
-from lerobot.motors.feetech import FeetechMotorsBus
-
-from . import SO100Follower
-from .config_so100_follower import SO100FollowerEndEffectorConfig
-
-logger = logging.getLogger(__name__)
-
-
-class SO100FollowerEndEffector(SO100Follower):
- """
- SO100Follower robot with end-effector space control.
-
- This robot inherits from SO100Follower but transforms actions from
- end-effector space to joint space before sending them to the motors.
- """
-
- config_class = SO100FollowerEndEffectorConfig
- name = "so100_follower_end_effector"
-
- def __init__(self, config: SO100FollowerEndEffectorConfig):
- super().__init__(config)
- self.bus = FeetechMotorsBus(
- port=self.config.port,
- motors={
- "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
- "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
- "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
- "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
- "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
- "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
- },
- calibration=self.calibration,
- )
-
- self.cameras = make_cameras_from_configs(config.cameras)
-
- self.config = config
-
- # Initialize the kinematics module for the so100 robot
- if self.config.urdf_path is None:
- raise ValueError(
- "urdf_path must be provided in the configuration for end-effector control. "
- "Please set urdf_path in your SO100FollowerEndEffectorConfig."
- )
-
- self.kinematics = RobotKinematics(
- urdf_path=self.config.urdf_path,
- target_frame_name=self.config.target_frame_name,
- )
-
- # Store the bounds for end-effector position
- self.end_effector_bounds = self.config.end_effector_bounds
-
- self.current_ee_pos = None
- self.current_joint_pos = None
-
- @property
- def action_features(self) -> dict[str, Any]:
- """
- Define action features for end-effector control.
- Returns dictionary with dtype, shape, and names.
- """
- return {
- "dtype": "float32",
- "shape": (4,),
- "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
- }
-
- def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
- """
- Transform action from end-effector space to joint space and send to motors.
-
- Args:
- action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
- or a numpy array with [delta_x, delta_y, delta_z]
-
- Returns:
- The joint-space action that was sent to the motors
- """
-
- if not self.is_connected:
- raise DeviceNotConnectedError(f"{self} is not connected.")
-
- # Convert action to numpy array if not already
- if isinstance(action, dict):
- if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
- delta_ee = np.array(
- [
- action["delta_x"] * self.config.end_effector_step_sizes["x"],
- action["delta_y"] * self.config.end_effector_step_sizes["y"],
- action["delta_z"] * self.config.end_effector_step_sizes["z"],
- ],
- dtype=np.float32,
- )
- if "gripper" not in action:
- action["gripper"] = [1.0]
- action = np.append(delta_ee, action["gripper"])
- else:
- logger.warning(
- f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
- )
- action = np.zeros(4, dtype=np.float32)
-
- if self.current_joint_pos is None:
- # Read current joint positions
- current_joint_pos = self.bus.sync_read("Present_Position")
- self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
-
- # Calculate current end-effector position using forward kinematics
- if self.current_ee_pos is None:
- self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
-
- # Set desired end-effector position by adding delta
- desired_ee_pos = np.eye(4)
- desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
-
- # Add delta to position and clip to bounds
- desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
- if self.end_effector_bounds is not None:
- desired_ee_pos[:3, 3] = np.clip(
- desired_ee_pos[:3, 3],
- self.end_effector_bounds["min"],
- self.end_effector_bounds["max"],
- )
-
- # Compute inverse kinematics to get joint positions
- target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
- self.current_joint_pos, desired_ee_pos
- )
-
- # Create joint space action dictionary
- joint_action = {
- f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
- }
-
- # Handle gripper separately if included in action
- # Gripper delta action is in the range 0 - 2,
- # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
- joint_action["gripper.pos"] = np.clip(
- self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
- 5,
- self.config.max_gripper_pos,
- )
-
- self.current_ee_pos = desired_ee_pos.copy()
- self.current_joint_pos = target_joint_values_in_degrees.copy()
- self.current_joint_pos[-1] = joint_action["gripper.pos"]
-
- # Send joint space action to parent class
- return super().send_action(joint_action)
-
- def get_observation(self) -> dict[str, Any]:
- if not self.is_connected:
- raise DeviceNotConnectedError(f"{self} is not connected.")
-
- # Read arm position
- start = time.perf_counter()
- obs_dict = self.bus.sync_read("Present_Position")
- obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
- dt_ms = (time.perf_counter() - start) * 1e3
- logger.debug(f"{self} read state: {dt_ms:.1f}ms")
-
- # Capture images from cameras
- for cam_key, cam in self.cameras.items():
- start = time.perf_counter()
- obs_dict[cam_key] = cam.async_read()
- dt_ms = (time.perf_counter() - start) * 1e3
- logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
-
- return obs_dict
-
- def reset(self):
- self.current_ee_pos = None
- self.current_joint_pos = None
diff --git a/src/lerobot/robots/so101_follower/__init__.py b/src/lerobot/robots/so101_follower/__init__.py
index f6615b15b..9ff2baf45 100644
--- a/src/lerobot/robots/so101_follower/__init__.py
+++ b/src/lerobot/robots/so101_follower/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_so101_follower import SO101FollowerConfig
from .so101_follower import SO101Follower
diff --git a/src/lerobot/robots/so101_follower/config_so101_follower.py b/src/lerobot/robots/so101_follower/config_so101_follower.py
index be630e6ac..03c3530c2 100644
--- a/src/lerobot/robots/so101_follower/config_so101_follower.py
+++ b/src/lerobot/robots/so101_follower/config_so101_follower.py
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
- max_relative_target: int | None = None
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
+ max_relative_target: float | dict[str, float] | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx
deleted file mode 100644
index c49807d93..000000000
--- a/src/lerobot/robots/so101_follower/so101.mdx
+++ /dev/null
@@ -1,381 +0,0 @@
-# SO-101
-
-In the steps below, we explain how to assemble our flagship robot, the SO-101.
-
-## Source the parts
-
-Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts.
-And advise if it's your first time printing or if you don't own a 3D printer.
-
-## Install LeRobot 🤗
-
-To install LeRobot, follow our [Installation Guide](./installation)
-
-In addition to these instructions, you need to install the Feetech SDK:
-```bash
-pip install -e ".[feetech]"
-```
-
-## Step-by-Step Assembly Instructions
-
-The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below.
-
-| Leader-Arm Axis | Motor | Gear Ratio |
-|-----------------|:-------:|:----------:|
-| Base / Shoulder Pan | 1 | 1 / 191 |
-| Shoulder Lift | 2 | 1 / 345 |
-| Elbow Flex | 3 | 1 / 191 |
-| Wrist Flex | 4 | 1 / 147 |
-| Wrist Roll | 5 | 1 / 147 |
-| Gripper | 6 | 1 / 147 |
-
-### Clean Parts
-Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
-
-### Joint 1
-
-- Place the first motor into the base.
-- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
-- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
-- Install both motor horns, securing the top horn with a M3x6mm screw.
-- Attach the shoulder part.
-- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
-- Add the shoulder motor holder.
-
-
-
-
-
-### Joint 2
-
-- Slide the second motor in from the top.
-- Fasten the second motor with 4 M2x6mm screws.
-- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
-- Attach the upper arm with 4 M3x6mm screws on each side.
-
-
-
-
-
-### Joint 3
-
-- Insert motor 3 and fasten using 4 M2x6mm screws
-- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
-- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
-
-
-
-
-
-### Joint 4
-
-- Slide over motor holder 4.
-- Slide in motor 4.
-- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
-
-
-
-
-
-### Joint 5
-
-- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
-- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
-- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
-
-
-
-
-
-### Gripper / Handle
-
-
-
-
-- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
-- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
-- Attach the motor horns and again use a M3x6mm horn screw.
-- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
-
-
-
-
-
-
-
-
-- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
-- Attach the handle to motor 5 using 1 M2x6mm screw.
-- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
-- Attach the follower trigger with 4 M3x6mm screws.
-
-
-
-
-
-
-
-
-## Configure the motors
-
-### 1. Find the USB ports associated with each arm
-
-To find the port for each bus servo adapter, run this script:
-```bash
-python -m lerobot.find_port
-```
-
-
-
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
-Remove the USB cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/tty.usbmodem575E0032081
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm.
-
-
-
-
-On Linux, you might need to give access to the USB ports by running:
-```bash
-sudo chmod 666 /dev/ttyACM0
-sudo chmod 666 /dev/ttyACM1
-```
-
-Example output:
-
-```
-Finding all available ports for the MotorBus.
-['/dev/ttyACM0', '/dev/ttyACM1']
-Remove the usb cable from your MotorsBus and press Enter when done.
-
-[...Disconnect corresponding leader or follower arm and press Enter...]
-
-The port of this MotorsBus is /dev/ttyACM1
-Reconnect the USB cable.
-```
-
-Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm.
-
-
-
-
-### 2. Set the motors ids and baudrates
-
-Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate.
-
-To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once.
-
-If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match.
-
-The video below shows the sequence of steps for setting the motor ids.
-
-##### Setup motors video
-
-
-
-
-
-#### Follower
-
-Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter.
-
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --robot.type=so101_follower \
- --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
-
-config = SO101FollowerConfig(
- port="/dev/tty.usbmodem585A0076841",
- id="my_awesome_follower_arm",
-)
-follower = SO101Follower(config)
-follower.setup_motors()
-```
-
-
-
-You should see the following instruction
-```bash
-Connect the controller board to the 'gripper' motor only and press enter.
-```
-
-As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor.
-
-
-Troubleshooting
-
- If you get an error at that point, check your cables and make sure they are plugged in properly:
-
-
Power supply
-
USB cable between your computer and the controller board
-
The 3-pin cable from the controller board to the motor
-
-
- If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
-
-
-You should then see the following message:
-```bash
-'gripper' motor id set to 6
-```
-
-Followed by the next instruction:
-```bash
-Connect the controller board to the 'wrist_roll' motor only and press enter.
-```
-
-You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one.
-
-Repeat the operation for each motor as instructed.
-
-> [!TIP]
-> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board.
-
-When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm.
-
-#### Leader
-Do the same steps for the leader arm.
-
-
-
-
-```bash
-python -m lerobot.setup_motors \
- --teleop.type=so101_leader \
- --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step
-```
-
-
-
-```python
-from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig
-
-config = SO101LeaderConfig(
- port="/dev/tty.usbmodem585A0076841",
- id="my_awesome_leader_arm",
-)
-leader = SO101Leader(config)
-leader.setup_motors()
-```
-
-
-
-## Calibrate
-
-Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
-The calibration process is very important because it allows a neural network trained on one robot to work on another.
-
-#### Follower
-
-Run the following command or API example to calibrate the follower arm:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --robot.type=so101_follower \
- --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --robot.id=my_awesome_follower_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
-
-config = SO101FollowerConfig(
- port="/dev/tty.usbmodem585A0076891",
- id="my_awesome_follower_arm",
-)
-
-follower = SO101Follower(config)
-follower.connect(calibrate=False)
-follower.calibrate()
-follower.disconnect()
-```
-
-
-
-The video below shows how to perform the calibration. First you need to move the robot to the position where all joints are in the middle of their ranges. Then after pressing enter you have to move each joint through its full range of motion.
-
-##### Calibration video
-
-
-
-
-
-#### Leader
-
-Do the same steps to calibrate the leader arm, run the following command or API example:
-
-
-
-
-```bash
-python -m lerobot.calibrate \
- --teleop.type=so101_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot
- --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name
-```
-
-
-
-```python
-from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader
-
-config = SO101LeaderConfig(
- port="/dev/tty.usbmodem58760431551",
- id="my_awesome_leader_arm",
-)
-
-leader = SO101Leader(config)
-leader.connect(calibrate=False)
-leader.calibrate()
-leader.disconnect()
-```
-
-
-
-Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
-
-> [!TIP]
-> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx
new file mode 120000
index 000000000..27b892660
--- /dev/null
+++ b/src/lerobot/robots/so101_follower/so101.mdx
@@ -0,0 +1 @@
+../../../../docs/source/so101.mdx
\ No newline at end of file
diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py
index 3ae3c3967..acfd4bd11 100644
--- a/src/lerobot/robots/so101_follower/so101_follower.py
+++ b/src/lerobot/robots/so101_follower/so101_follower.py
@@ -20,12 +20,12 @@ from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -92,6 +92,9 @@ class SO101Follower(Robot):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
for cam in self.cameras.values():
@@ -105,6 +108,16 @@ class SO101Follower(Robot):
return self.bus.is_calibrated
def calibrate(self) -> None:
+ if self.calibration:
+ # self.calibration is not empty here
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
@@ -144,6 +157,13 @@ class SO101Follower(Robot):
self.bus.write("I_Coefficient", motor, 0)
self.bus.write("D_Coefficient", motor, 32)
+ if motor == "gripper":
+ self.bus.write(
+ "Max_Torque_Limit", motor, 500
+ ) # 50% of the max torque limit to avoid burnout
+ self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
+ self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
+
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
diff --git a/src/lerobot/robots/stretch3/README.md b/src/lerobot/robots/stretch3/README.md
index 982e72571..027f12d65 100644
--- a/src/lerobot/robots/stretch3/README.md
+++ b/src/lerobot/robots/stretch3/README.md
@@ -5,16 +5,17 @@ This tutorial explains how to use [Stretch 3](https://hello-robot.com/stretch-3-
Familiarize yourself with Stretch by following its [tutorials](https://docs.hello-robot.com/0.3/getting_started/hello_robot/) (recommended).
To use LeRobot on Stretch, 3 options are available:
+
- [tethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup)
- [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup)
- ssh directly into Stretch (you will first need to install and configure openssh-server on stretch using one of the two above setups)
-
## Install LeRobot
On Stretch's CLI, follow these steps:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
+
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
@@ -24,6 +25,7 @@ rm ~/miniconda3/miniconda.sh
```
2. Comment out these lines in `~/.profile` (this can mess up paths used by conda and ~/.local/bin should already be in your PATH)
+
```
# set PATH so it includes user's private bin if it exists
if [ -d "$HOME/.local/bin" ] ; then
@@ -34,21 +36,25 @@ fi
3. Restart shell or `source ~/.bashrc`
4. Create and activate a fresh conda environment for lerobot
+
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
5. Clone LeRobot:
+
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
6. When using `miniconda`, install `ffmpeg` in your environment:
+
```bash
conda install ffmpeg -c conda-forge
```
7. Install LeRobot with stretch dependencies:
+
```bash
cd ~/lerobot && pip install -e ".[stretch]"
```
@@ -56,6 +62,7 @@ cd ~/lerobot && pip install -e ".[stretch]"
> **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.`
8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready:
+
```bash
stretch_system_check.py
```
@@ -63,6 +70,7 @@ stretch_system_check.py
> **Note:** You may need to free the "robot process" after booting Stretch by running `stretch_free_robot_process.py`. For more info this Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#turning-off-gamepad-teleoperation).
You should get something like this:
+
```bash
For use with S T R E T C H (R) from Hello Robot Inc.
---------------------------------------------------------------------
@@ -89,11 +97,13 @@ Serial Number = stretch-se3-3054
**Calibrate (Optional)**
Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=calibrate
```
+
This is equivalent to running `stretch_robot_home.py`
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
@@ -104,28 +114,33 @@ Before trying teleoperation, you need to activate the gamepad controller by pres
Now try out teleoperation (see above documentation to learn about the gamepad controls):
> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`.
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
--control.type=teleoperate
```
+
This is essentially the same as running `stretch_gamepad_teleop.py`
**Record a dataset**
Once you're familiar with the gamepad controls and after a bit of practice, you can try to record your first dataset with Stretch.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
+
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
+
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record one episode:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
@@ -145,6 +160,7 @@ python lerobot/scripts/control_robot.py \
**Replay an episode**
Now try to replay this episode (make sure the robot's initial position is the same):
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=stretch \
@@ -154,8 +170,4 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
-Follow [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) to train a policy on your data and run inference on your robot. You will need to adapt the code for Stretch.
-
-> TODO(rcadene, aliberts): Add already setup environment and policy yaml configuration files
-
If you need help, please reach out on Discord in the channel `#stretch3-mobile-arm`.
diff --git a/src/lerobot/robots/stretch3/__init__.py b/src/lerobot/robots/stretch3/__init__.py
index e2a859cde..b3070bbd6 100644
--- a/src/lerobot/robots/stretch3/__init__.py
+++ b/src/lerobot/robots/stretch3/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configuration_stretch3 import Stretch3RobotConfig
from .robot_stretch3 import Stretch3Robot
diff --git a/src/lerobot/robots/stretch3/configuration_stretch3.py b/src/lerobot/robots/stretch3/configuration_stretch3.py
index 9fcf8f742..c1226bf90 100644
--- a/src/lerobot/robots/stretch3/configuration_stretch3.py
+++ b/src/lerobot/robots/stretch3/configuration_stretch3.py
@@ -24,11 +24,6 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
- # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
- max_relative_target: int | None = None
-
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
@@ -54,5 +49,3 @@ class Stretch3RobotConfig(RobotConfig):
),
}
)
-
- mock: bool = False
diff --git a/src/lerobot/robots/stretch3/robot_stretch3.py b/src/lerobot/robots/stretch3/robot_stretch3.py
index b907d6a3f..73df360b2 100644
--- a/src/lerobot/robots/stretch3/robot_stretch3.py
+++ b/src/lerobot/robots/stretch3/robot_stretch3.py
@@ -22,8 +22,8 @@ from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import get_nested_item
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from ..robot import Robot
from .configuration_stretch3 import Stretch3RobotConfig
@@ -164,10 +164,6 @@ class Stretch3Robot(Robot):
# TODO(aliberts): return action_sent when motion is limited
return action
- def print_logs(self) -> None:
- pass
- # TODO(aliberts): move robot-specific logs logic here
-
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py
index 435303c6e..aca5c8716 100644
--- a/src/lerobot/robots/utils.py
+++ b/src/lerobot/robots/utils.py
@@ -14,13 +14,16 @@
import logging
from pprint import pformat
+from typing import cast
-from lerobot.robots import RobotConfig
+from lerobot.utils.import_utils import make_device_from_device_class
+from .config import RobotConfig
from .robot import Robot
def make_robot_from_config(config: RobotConfig) -> Robot:
+ # TODO(Steven): Consider just using the make_device_from_device_class for all types
if config.type == "koch_follower":
from .koch_follower import KochFollower
@@ -29,10 +32,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .so100_follower import SO100Follower
return SO100Follower(config)
- elif config.type == "so100_follower_end_effector":
- from .so100_follower import SO100FollowerEndEffector
-
- return SO100FollowerEndEffector(config)
elif config.type == "so101_follower":
from .so101_follower import SO101Follower
@@ -49,16 +48,36 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
from .viperx import ViperX
return ViperX(config)
+ elif config.type == "hope_jr_hand":
+ from .hope_jr import HopeJrHand
+
+ return HopeJrHand(config)
+ elif config.type == "hope_jr_arm":
+ from .hope_jr import HopeJrArm
+
+ return HopeJrArm(config)
+ elif config.type == "bi_so100_follower":
+ from .bi_so100_follower import BiSO100Follower
+
+ return BiSO100Follower(config)
+ elif config.type == "reachy2":
+ from .reachy2 import Reachy2Robot
+
+ return Reachy2Robot(config)
elif config.type == "mock_robot":
from tests.mocks.mock_robot import MockRobot
return MockRobot(config)
else:
- raise ValueError(config.type)
+ try:
+ return cast(Robot, make_device_from_device_class(config))
+ except Exception as e:
+ raise ValueError(f"Error creating robot with config {config}: {e}") from e
+# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
def ensure_safe_goal_position(
- goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
+ goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
) -> dict[str, float]:
"""Caps relative action target magnitude for safety."""
diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md
index 445368e7a..2e8fc7289 100644
--- a/src/lerobot/robots/viperx/README.md
+++ b/src/lerobot/robots/viperx/README.md
@@ -4,12 +4,12 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro
Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer.
-
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
+
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
@@ -21,29 +21,34 @@ rm ~/miniconda3/miniconda.sh
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
+
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
4. Clone LeRobot:
+
```bash
git clone https://github.com/huggingface/lerobot.git ~/lerobot
```
5. When using `miniconda`, install `ffmpeg` in your environment:
+
```bash
conda install ffmpeg -c conda-forge
```
6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense):
+
```bash
cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]"
```
## Teleoperate
-**/!\ FOR SAFETY, READ THIS /!\**
+\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly:
+
1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms,
2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics.
@@ -58,7 +63,8 @@ python lerobot/scripts/control_robot.py \
--control.type=teleoperate
```
-By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
+By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
@@ -71,17 +77,20 @@ python lerobot/scripts/control_robot.py \
Once you're familiar with teleoperation, you can record your first dataset with Aloha.
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
+
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face repository name in a variable to run these commands:
+
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
Record 2 episodes and upload your dataset to the hub:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
@@ -101,22 +110,25 @@ python lerobot/scripts/control_robot.py \
## Visualize a dataset
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
+
```bash
echo ${HF_USER}/aloha_test
```
-If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with:
+If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun):
+
```bash
-python -m lerobot.scripts.visualize_dataset_html \
- --repo-id ${HF_USER}/aloha_test
+lerobot-dataset-viz \
+ --repo-id ${HF_USER}/aloha_test --episode 0
```
## Replay an episode
-**/!\ FOR SAFETY, READ THIS /!\**
+\*\*/!\ FOR SAFETY, READ THIS /!\*\*
Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above.
Now try to replay the first episode on your robot:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
@@ -129,9 +141,10 @@ python lerobot/scripts/control_robot.py \
## Train a policy
-To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
+
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--dataset.repo_id=${HF_USER}/aloha_test \
--policy.type=act \
--output_dir=outputs/train/act_aloha_test \
@@ -141,10 +154,11 @@ python -m lerobot.scripts.train \
```
Let's explain it:
+
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
-4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
-5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
+3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
+4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
@@ -153,6 +167,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../src/lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
+
```bash
python lerobot/scripts/control_robot.py \
--robot.type=aloha \
@@ -171,12 +186,11 @@ python lerobot/scripts/control_robot.py \
```
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
-1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
+
+1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`).
3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`.
## More
-Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation.
-
If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`.
diff --git a/src/lerobot/robots/viperx/__init__.py b/src/lerobot/robots/viperx/__init__.py
index 522d02f1c..bfba07fc7 100644
--- a/src/lerobot/robots/viperx/__init__.py
+++ b/src/lerobot/robots/viperx/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_viperx import ViperXConfig
from .viperx import ViperX
diff --git a/src/lerobot/robots/viperx/config_viperx.py b/src/lerobot/robots/viperx/config_viperx.py
index 4922f1d18..ed3876a9c 100644
--- a/src/lerobot/robots/viperx/config_viperx.py
+++ b/src/lerobot/robots/viperx/config_viperx.py
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
- # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
- # the number of motors in your follower arms.
+ # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
+ # names to the max_relative_target value for that motor.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
- max_relative_target: int | None = 5
+ max_relative_target: float | dict[str, float] = 5.0
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)
diff --git a/src/lerobot/robots/viperx/viperx.py b/src/lerobot/robots/viperx/viperx.py
index 881640cd5..31e99ffdb 100644
--- a/src/lerobot/robots/viperx/viperx.py
+++ b/src/lerobot/robots/viperx/viperx.py
@@ -18,13 +18,13 @@ from functools import cached_property
from typing import Any
from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.constants import OBS_STATE
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
+from lerobot.utils.constants import OBS_STATE
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
diff --git a/src/lerobot/scripts/display_sys_info.py b/src/lerobot/scripts/display_sys_info.py
deleted file mode 100644
index 4d3cc291f..000000000
--- a/src/lerobot/scripts/display_sys_info.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Use this script to get a quick summary of your system config.
-It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
-"""
-
-import platform
-
-HAS_HF_HUB = True
-HAS_HF_DATASETS = True
-HAS_NP = True
-HAS_TORCH = True
-HAS_LEROBOT = True
-
-try:
- import huggingface_hub
-except ImportError:
- HAS_HF_HUB = False
-
-try:
- import datasets
-except ImportError:
- HAS_HF_DATASETS = False
-
-try:
- import numpy as np
-except ImportError:
- HAS_NP = False
-
-try:
- import torch
-except ImportError:
- HAS_TORCH = False
-
-try:
- import lerobot
-except ImportError:
- HAS_LEROBOT = False
-
-
-lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
-hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
-hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
-np_version = np.__version__ if HAS_NP else "N/A"
-
-torch_version = torch.__version__ if HAS_TORCH else "N/A"
-torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
-cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
-
-
-# TODO(aliberts): refactor into an actual command `lerobot env`
-def display_sys_info() -> dict:
- """Run this to get basic system info to help for tracking issues & bugs."""
- info = {
- "`lerobot` version": lerobot_version,
- "Platform": platform.platform(),
- "Python version": platform.python_version(),
- "Huggingface_hub version": hf_hub_version,
- "Dataset version": hf_datasets_version,
- "Numpy version": np_version,
- "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
- "Cuda version": cuda_version,
- "Using GPU in script?": "",
- # "Using distributed or parallel set-up in script?": "",
- }
- print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
- print(format_dict(info))
- return info
-
-
-def format_dict(d: dict) -> str:
- return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
-
-
-if __name__ == "__main__":
- display_sys_info()
diff --git a/src/lerobot/calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py
similarity index 92%
rename from src/lerobot/calibrate.py
rename to src/lerobot/scripts/lerobot_calibrate.py
index 37a9d5bdf..0f247caef 100644
--- a/src/lerobot/calibrate.py
+++ b/src/lerobot/scripts/lerobot_calibrate.py
@@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator).
Example:
```shell
-python -m lerobot.calibrate \
+lerobot-calibrate \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
@@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
+ hope_jr,
koch_follower,
lekiwi,
make_robot_from_config,
@@ -45,11 +46,13 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
+ homunculus,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
+from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.utils import init_logging
@@ -80,5 +83,10 @@ def calibrate(cfg: CalibrateConfig):
device.disconnect()
-if __name__ == "__main__":
+def main():
+ register_third_party_devices()
calibrate()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/visualize_dataset.py b/src/lerobot/scripts/lerobot_dataset_viz.py
similarity index 90%
rename from src/lerobot/scripts/visualize_dataset.py
rename to src/lerobot/scripts/lerobot_dataset_viz.py
index 37db66ddf..55708d9a9 100644
--- a/src/lerobot/scripts/visualize_dataset.py
+++ b/src/lerobot/scripts/lerobot_dataset_viz.py
@@ -29,14 +29,14 @@ Examples:
- Visualize data stored on a local machine:
```
-local$ python -m lerobot.scripts.visualize_dataset \
+local$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
- Visualize data stored on a distant machine with a local viewer:
```
-distant$ python -m lerobot.scripts.visualize_dataset \
+distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--save 1 \
@@ -50,7 +50,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
-distant$ python -m lerobot.scripts.visualize_dataset \
+distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
@@ -65,8 +65,8 @@ import argparse
import gc
import logging
import time
+from collections.abc import Iterator
from pathlib import Path
-from typing import Iterator
import numpy as np
import rerun as rr
@@ -75,12 +75,13 @@ import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
- from_idx = dataset.episode_data_index["from"][episode_index].item()
- to_idx = dataset.episode_data_index["to"][episode_index].item()
+ from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
+ to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
@@ -156,20 +157,20 @@ def visualize_dataset(
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# display each dimension of action space (e.g. actuators command)
- if "action" in batch:
- for dim_idx, val in enumerate(batch["action"][i]):
- rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
+ if ACTION in batch:
+ for dim_idx, val in enumerate(batch[ACTION][i]):
+ rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
- if "observation.state" in batch:
- for dim_idx, val in enumerate(batch["observation.state"][i]):
+ if OBS_STATE in batch:
+ for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
- if "next.done" in batch:
- rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+ if DONE in batch:
+ rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
- if "next.reward" in batch:
- rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+ if REWARD in batch:
+ rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
@@ -283,7 +284,7 @@ def main():
tolerance_s = kwargs.pop("tolerance_s")
logging.info("Loading dataset")
- dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
+ dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))
diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/lerobot_eval.py
similarity index 62%
rename from src/lerobot/scripts/eval.py
rename to src/lerobot/scripts/lerobot_eval.py
index d85ac27b3..d45be5c42 100644
--- a/src/lerobot/scripts/eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
for 10 episodes.
```
-python -m lerobot.scripts.eval \
+lerobot-eval \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
@@ -32,7 +32,7 @@ python -m lerobot.scripts.eval \
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
```
-python -m lerobot.scripts.eval \
+lerobot-eval \
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
--env.type=pusht \
--eval.batch_size=10 \
@@ -46,16 +46,20 @@ Note that in both examples, the repo/folder should contain at least `config.json
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
"""
+import concurrent.futures as cf
import json
import logging
import threading
import time
+from collections import defaultdict
+from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
+from functools import partial
from pathlib import Path
from pprint import pformat
-from typing import Callable
+from typing import Any, TypedDict
import einops
import gymnasium as gym
@@ -68,10 +72,16 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
-from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
-from lerobot.policies.factory import make_policy
+from lerobot.envs.utils import (
+ add_envs_task,
+ check_env_attributes_and_types,
+ close_envs,
+ preprocess_observation,
+)
+from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import get_device_from_parameters
+from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -84,6 +94,8 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
+ preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -120,7 +132,6 @@ def rollout(
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
- device = get_device_from_parameters(policy)
# Reset the policy and environments.
policy.reset()
@@ -145,29 +156,26 @@ def rollout(
leave=False,
)
check_env_attributes_and_types(env)
- while not np.all(done):
+ while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
- observation = {
- key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
- }
-
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
-
+ observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
+ action = postprocessor(action)
# Convert to CPU / numpy.
- action = action.to("cpu").numpy()
- assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
+ action_numpy: np.ndarray = action.to("cpu").numpy()
+ assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
- observation, reward, terminated, truncated, info = env.step(action)
+ observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
@@ -179,9 +187,14 @@ def rollout(
successes = [False] * env.num_envs
# Keep track of which environments are done so far.
+ # Mark the episode as done if we reach the maximum step limit.
+ # This ensures that the rollout always terminates cleanly at `max_steps`,
+ # and allows logging/saving (e.g., videos) to be triggered consistently.
done = terminated | truncated | done
+ if step + 1 == max_steps:
+ done = np.ones_like(done, dtype=bool)
- all_actions.append(torch.from_numpy(action))
+ all_actions.append(torch.from_numpy(action_numpy))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
@@ -200,7 +213,7 @@ def rollout(
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
- "action": torch.stack(all_actions, dim=1),
+ ACTION: torch.stack(all_actions, dim=1),
"reward": torch.stack(all_rewards, dim=1),
"success": torch.stack(all_successes, dim=1),
"done": torch.stack(all_dones, dim=1),
@@ -209,7 +222,7 @@ def rollout(
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
- ret["observation"] = stacked_observations
+ ret[OBS_STR] = stacked_observations
if hasattr(policy, "use_original_modules"):
policy.use_original_modules()
@@ -220,6 +233,8 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
+ preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -296,8 +311,10 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
- env,
- policy,
+ env=env,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
@@ -423,28 +440,28 @@ def _compile_episode_data(
"""
ep_dicts = []
total_frames = 0
- for ep_ix in range(rollout_data["action"].shape[0]):
+ for ep_ix in range(rollout_data[ACTION].shape[0]):
# + 2 to include the first done frame and the last observation frame.
num_frames = done_indices[ep_ix].item() + 2
total_frames += num_frames
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
- "action": rollout_data["action"][ep_ix, : num_frames - 1],
+ ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
- "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
+ DONE: rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
- "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
+ REWARD: rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
}
# For the last observation frame, all other keys will just be copy padded.
for k in ep_dict:
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
- for key in rollout_data["observation"]:
- ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
+ for key in rollout_data[OBS_STR]:
+ ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames]
ep_dicts.append(ep_dict)
@@ -471,7 +488,7 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
- env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
+ envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
@@ -479,28 +496,259 @@ def eval_main(cfg: EvalPipelineConfig):
cfg=cfg.policy,
env_cfg=cfg.env,
)
- policy.eval()
+ policy.eval()
+ preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
+ preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
+ )
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
- info = eval_policy(
- env,
- policy,
- cfg.eval.n_episodes,
+ info = eval_policy_all(
+ envs=envs,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
+ max_parallel_tasks=cfg.env.max_parallel_tasks,
)
- print(info["aggregated"])
+ print("Overall Aggregated Metrics:")
+ print(info["overall"])
+
+ # Print per-suite stats
+ for task_group, task_group_info in info.items():
+ print(f"\nAggregated Metrics for {task_group}:")
+ print(task_group_info)
+ # Close all vec envs
+ close_envs(envs)
# Save info
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
- env.close()
-
logging.info("End of eval")
-if __name__ == "__main__":
+# ---- typed payload returned by one task eval ----
+class TaskMetrics(TypedDict):
+ sum_rewards: list[float]
+ max_rewards: list[float]
+ successes: list[bool]
+ video_paths: list[str]
+
+
+ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
+
+
+def eval_one(
+ env: gym.vector.VectorEnv,
+ *,
+ policy: PreTrainedPolicy,
+ preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
+ n_episodes: int,
+ max_episodes_rendered: int,
+ videos_dir: Path | None,
+ return_episode_data: bool,
+ start_seed: int | None,
+) -> TaskMetrics:
+ """Evaluates one task_id of one suite using the provided vec env."""
+
+ task_videos_dir = videos_dir
+
+ task_result = eval_policy(
+ env=env,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=n_episodes,
+ max_episodes_rendered=max_episodes_rendered,
+ videos_dir=task_videos_dir,
+ return_episode_data=return_episode_data,
+ start_seed=start_seed,
+ )
+
+ per_episode = task_result["per_episode"]
+ return TaskMetrics(
+ sum_rewards=[ep["sum_reward"] for ep in per_episode],
+ max_rewards=[ep["max_reward"] for ep in per_episode],
+ successes=[ep["success"] for ep in per_episode],
+ video_paths=task_result.get("video_paths", []),
+ )
+
+
+def run_one(
+ task_group: str,
+ task_id: int,
+ env,
+ *,
+ policy,
+ preprocessor,
+ postprocessor,
+ n_episodes: int,
+ max_episodes_rendered: int,
+ videos_dir: Path | None,
+ return_episode_data: bool,
+ start_seed: int | None,
+):
+ """
+ Run eval_one for a single (task_group, task_id, env).
+ Returns (task_group, task_id, task_metrics_dict).
+ This function is intentionally module-level to make it easy to test.
+ """
+ task_videos_dir = None
+ if videos_dir is not None:
+ task_videos_dir = videos_dir / f"{task_group}_{task_id}"
+ task_videos_dir.mkdir(parents=True, exist_ok=True)
+
+ # Call the existing eval_one (assumed to return TaskMetrics-like dict)
+ metrics = eval_one(
+ env,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=n_episodes,
+ max_episodes_rendered=max_episodes_rendered,
+ videos_dir=task_videos_dir,
+ return_episode_data=return_episode_data,
+ start_seed=start_seed,
+ )
+ # ensure we always provide video_paths key to simplify accumulation
+ if max_episodes_rendered > 0:
+ metrics.setdefault("video_paths", [])
+ return task_group, task_id, metrics
+
+
+def eval_policy_all(
+ envs: dict[str, dict[int, gym.vector.VectorEnv]],
+ policy,
+ preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
+ n_episodes: int,
+ *,
+ max_episodes_rendered: int = 0,
+ videos_dir: Path | None = None,
+ return_episode_data: bool = False,
+ start_seed: int | None = None,
+ max_parallel_tasks: int = 1,
+) -> dict:
+ """
+ Evaluate a nested `envs` dict: {task_group: {task_id: vec_env}}.
+ This implementation flattens tasks, runs them sequentially or via ThreadPoolExecutor,
+ accumulates per-group and overall statistics, and returns the same aggregate metrics
+ schema as the single-env evaluator (avg_sum_reward / avg_max_reward / pc_success / timings)
+ plus per-task infos.
+ """
+ start_t = time.time()
+
+ # Flatten envs into list of (task_group, task_id, env)
+ tasks = [(tg, tid, vec) for tg, group in envs.items() for tid, vec in group.items()]
+
+ # accumulators: track metrics at both per-group level and across all groups
+ group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
+ overall: dict[str, list] = {k: [] for k in ACC_KEYS}
+ per_task_infos: list[dict] = []
+
+ # small inline helper to accumulate one task's metrics into accumulators
+ def _accumulate_to(group: str, metrics: dict):
+ # metrics expected to contain 'sum_rewards', 'max_rewards', 'successes', optionally 'video_paths'
+ # but eval_one may store per-episode lists; we assume metrics uses scalars averaged per task as before.
+ # To be robust, accept scalars or lists.
+ def _append(key, value):
+ if value is None:
+ return
+ if isinstance(value, list):
+ group_acc[group][key].extend(value)
+ overall[key].extend(value)
+ else:
+ group_acc[group][key].append(value)
+ overall[key].append(value)
+
+ _append("sum_rewards", metrics.get("sum_rewards"))
+ _append("max_rewards", metrics.get("max_rewards"))
+ _append("successes", metrics.get("successes"))
+ # video_paths is list-like
+ paths = metrics.get("video_paths", [])
+ if paths:
+ group_acc[group]["video_paths"].extend(paths)
+ overall["video_paths"].extend(paths)
+
+ # Choose runner (sequential vs threaded)
+ task_runner = partial(
+ run_one,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=n_episodes,
+ max_episodes_rendered=max_episodes_rendered,
+ videos_dir=videos_dir,
+ return_episode_data=return_episode_data,
+ start_seed=start_seed,
+ )
+
+ if max_parallel_tasks <= 1:
+ # sequential path (single accumulator path on the main thread)
+ # NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
+ for task_group, task_id, env in tasks:
+ tg, tid, metrics = task_runner(task_group, task_id, env)
+ _accumulate_to(tg, metrics)
+ per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
+ else:
+ # threaded path: submit all tasks, consume completions on main thread and accumulate there
+ with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
+ fut2meta = {}
+ for task_group, task_id, env in tasks:
+ fut = executor.submit(task_runner, task_group, task_id, env)
+ fut2meta[fut] = (task_group, task_id)
+ for fut in cf.as_completed(fut2meta):
+ tg, tid, metrics = fut.result()
+ _accumulate_to(tg, metrics)
+ per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
+
+ # compute aggregated metrics helper (robust to lists/scalars)
+ def _agg_from_list(xs):
+ if not xs:
+ return float("nan")
+ arr = np.array(xs, dtype=float)
+ return float(np.nanmean(arr))
+
+ # compute per-group aggregates
+ groups_aggregated = {}
+ for group, acc in group_acc.items():
+ groups_aggregated[group] = {
+ "avg_sum_reward": _agg_from_list(acc["sum_rewards"]),
+ "avg_max_reward": _agg_from_list(acc["max_rewards"]),
+ "pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
+ "n_episodes": len(acc["sum_rewards"]),
+ "video_paths": list(acc["video_paths"]),
+ }
+
+ # overall aggregates
+ overall_agg = {
+ "avg_sum_reward": _agg_from_list(overall["sum_rewards"]),
+ "avg_max_reward": _agg_from_list(overall["max_rewards"]),
+ "pc_success": _agg_from_list(overall["successes"]) * 100 if overall["successes"] else float("nan"),
+ "n_episodes": len(overall["sum_rewards"]),
+ "eval_s": time.time() - start_t,
+ "eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
+ "video_paths": list(overall["video_paths"]),
+ }
+
+ return {
+ "per_task": per_task_infos,
+ "per_group": groups_aggregated,
+ "overall": overall_agg,
+ }
+
+
+def main():
init_logging()
eval_main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py
similarity index 94%
rename from src/lerobot/find_cameras.py
rename to src/lerobot/scripts/lerobot_find_cameras.py
index aff2f8c19..e17dca805 100644
--- a/src/lerobot/find_cameras.py
+++ b/src/lerobot/scripts/lerobot_find_cameras.py
@@ -20,11 +20,11 @@ Helper to find the camera devices available in your system.
Example:
```shell
-python -m lerobot.find_cameras
+lerobot-find-cameras
```
"""
-# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot.find_cameras realsense` flag to avoid confusion.
+# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot-find-cameras realsense` flag to avoid confusion.
# NOTE(Steven): macOS cameras sometimes report different FPS at init time, not an issue here as we don't specify FPS when opening the cameras, but the information displayed might not be truthful.
import argparse
@@ -32,7 +32,7 @@ import concurrent.futures
import logging
import time
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any
import numpy as np
from PIL import Image
@@ -46,14 +46,14 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
logger = logging.getLogger(__name__)
-def find_all_opencv_cameras() -> List[Dict[str, Any]]:
+def find_all_opencv_cameras() -> list[dict[str, Any]]:
"""
Finds all available OpenCV cameras plugged into the system.
Returns:
A list of all available OpenCV cameras with their metadata.
"""
- all_opencv_cameras_info: List[Dict[str, Any]] = []
+ all_opencv_cameras_info: list[dict[str, Any]] = []
logger.info("Searching for OpenCV cameras...")
try:
opencv_cameras = OpenCVCamera.find_cameras()
@@ -66,14 +66,14 @@ def find_all_opencv_cameras() -> List[Dict[str, Any]]:
return all_opencv_cameras_info
-def find_all_realsense_cameras() -> List[Dict[str, Any]]:
+def find_all_realsense_cameras() -> list[dict[str, Any]]:
"""
Finds all available RealSense cameras plugged into the system.
Returns:
A list of all available RealSense cameras with their metadata.
"""
- all_realsense_cameras_info: List[Dict[str, Any]] = []
+ all_realsense_cameras_info: list[dict[str, Any]] = []
logger.info("Searching for RealSense cameras...")
try:
realsense_cameras = RealSenseCamera.find_cameras()
@@ -88,7 +88,7 @@ def find_all_realsense_cameras() -> List[Dict[str, Any]]:
return all_realsense_cameras_info
-def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[str, Any]]:
+def find_and_print_cameras(camera_type_filter: str | None = None) -> list[dict[str, Any]]:
"""
Finds available cameras based on an optional filter and prints their information.
@@ -99,7 +99,7 @@ def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[s
Returns:
A list of all available cameras matching the filter, with their metadata.
"""
- all_cameras_info: List[Dict[str, Any]] = []
+ all_cameras_info: list[dict[str, Any]] = []
if camera_type_filter:
camera_type_filter = camera_type_filter.lower()
@@ -153,7 +153,7 @@ def save_image(
logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}")
-def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None:
+def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None:
"""Create and connect to a camera instance based on metadata."""
cam_type = cam_meta.get("type")
cam_id = cam_meta.get("id")
@@ -190,7 +190,7 @@ def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None:
def process_camera_image(
- cam_dict: Dict[str, Any], output_dir: Path, current_time: float
+ cam_dict: dict[str, Any], output_dir: Path, current_time: float
) -> concurrent.futures.Future | None:
"""Capture and process an image from a single camera."""
cam = cam_dict["instance"]
@@ -216,7 +216,7 @@ def process_camera_image(
return None
-def cleanup_cameras(cameras_to_use: List[Dict[str, Any]]):
+def cleanup_cameras(cameras_to_use: list[dict[str, Any]]):
"""Disconnect all cameras."""
logger.info(f"Disconnecting {len(cameras_to_use)} cameras...")
for cam_dict in cameras_to_use:
@@ -286,7 +286,7 @@ def save_images_from_all_cameras(
print(f"Image capture finished. Images saved to {output_dir}")
-if __name__ == "__main__":
+def main():
parser = argparse.ArgumentParser(
description="Unified camera utility script for listing cameras and capturing images."
)
@@ -313,3 +313,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
save_images_from_all_cameras(**vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py
similarity index 93%
rename from src/lerobot/scripts/find_joint_limits.py
rename to src/lerobot/scripts/lerobot_find_joint_limits.py
index f7e07514f..07d57a760 100644
--- a/src/lerobot/scripts/find_joint_limits.py
+++ b/src/lerobot/scripts/lerobot_find_joint_limits.py
@@ -20,13 +20,13 @@ Simple script to control a robot from teleoperation.
Example:
```shell
-python -m lerobot.scripts.server.find_joint_limits \
- --robot.type=so100_follower \
- --robot.port=/dev/tty.usbmodem58760431541 \
- --robot.id=black \
- --teleop.type=so100_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \
- --teleop.id=blue
+lerobot-find-joint-limits \
+ --robot.type=so100_follower \
+ --robot.port=/dev/tty.usbmodem58760431541 \
+ --robot.id=black \
+ --teleop.type=so100_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \
+ --teleop.id=blue
```
"""
@@ -117,5 +117,9 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
busy_wait(0.01)
-if __name__ == "__main__":
+def main():
find_joint_and_ee_bounds()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/find_port.py b/src/lerobot/scripts/lerobot_find_port.py
similarity index 97%
rename from src/lerobot/find_port.py
rename to src/lerobot/scripts/lerobot_find_port.py
index cf0282507..e32b9cb99 100644
--- a/src/lerobot/find_port.py
+++ b/src/lerobot/scripts/lerobot_find_port.py
@@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus.
Example:
```shell
-python -m lerobot.find_port
+lerobot-find-port
```
"""
@@ -61,5 +61,9 @@ def find_port():
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
-if __name__ == "__main__":
+def main():
find_port()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/visualize_image_transforms.py b/src/lerobot/scripts/lerobot_imgtransform_viz.py
similarity index 97%
rename from src/lerobot/scripts/visualize_image_transforms.py
rename to src/lerobot/scripts/lerobot_imgtransform_viz.py
index 14caf89df..bc13f0508 100644
--- a/src/lerobot/scripts/visualize_image_transforms.py
+++ b/src/lerobot/scripts/lerobot_imgtransform_viz.py
@@ -20,10 +20,10 @@ Additionally, each individual transform can be visualized separately as well as
Example:
```bash
-python -m lerobot.scripts.visualize_image_transforms \
- --repo_id=lerobot/pusht \
- --episodes='[0]' \
- --image_transforms.enable=True
+lerobot-imgtransform-viz \
+ --repo_id=lerobot/pusht \
+ --episodes='[0]' \
+ --image_transforms.enable=True
```
"""
@@ -126,5 +126,9 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
-if __name__ == "__main__":
+def main():
visualize_image_transforms()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/lerobot_info.py b/src/lerobot/scripts/lerobot_info.py
new file mode 100644
index 000000000..9b49cad18
--- /dev/null
+++ b/src/lerobot/scripts/lerobot_info.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Use this script to get a quick summary of your system config.
+It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
+
+Example:
+
+```shell
+lerobot-info
+```
+"""
+
+import importlib
+import platform
+
+
+def get_package_version(package_name: str) -> str:
+ """Get the version of a package if it exists, otherwise return 'N/A'."""
+ try:
+ module = importlib.import_module(package_name)
+ return getattr(module, "__version__", "Installed (version not found)")
+ except ImportError:
+ return "N/A"
+
+
+def get_sys_info() -> dict:
+ """Run this to get basic system info to help for tracking issues & bugs."""
+ # General package versions
+ info = {
+ "lerobot version": get_package_version("lerobot"),
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "Huggingface Hub version": get_package_version("huggingface_hub"),
+ "Datasets version": get_package_version("datasets"),
+ "Numpy version": get_package_version("numpy"),
+ }
+
+ # PyTorch and GPU specific information
+ torch_version = "N/A"
+ torch_cuda_available = "N/A"
+ cuda_version = "N/A"
+ gpu_model = "N/A"
+ try:
+ import torch
+
+ torch_version = torch.__version__
+ torch_cuda_available = torch.cuda.is_available()
+ if torch_cuda_available:
+ cuda_version = torch.version.cuda
+ # Gets the name of the first available GPU
+ gpu_model = torch.cuda.get_device_name(0)
+ except ImportError:
+ # If torch is not installed, the default "N/A" values will be used.
+ pass
+
+ info.update(
+ {
+ "PyTorch version": torch_version,
+ "Is PyTorch built with CUDA support?": torch_cuda_available,
+ "Cuda version": cuda_version,
+ "GPU model": gpu_model,
+ "Using GPU in script?": "",
+ }
+ )
+
+ return info
+
+
+def format_dict_for_markdown(d: dict) -> str:
+ """Formats a dictionary into a markdown-friendly bulleted list."""
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()])
+
+
+def main():
+ system_info = get_sys_info()
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
+ print(format_dict_for_markdown(system_info))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/record.py b/src/lerobot/scripts/lerobot_record.py
similarity index 52%
rename from src/lerobot/record.py
rename to src/lerobot/scripts/lerobot_record.py
index 9cfbcad2b..f233aef38 100644
--- a/src/lerobot/record.py
+++ b/src/lerobot/scripts/lerobot_record.py
@@ -18,14 +18,15 @@ Records a dataset. Actions for the robot can be either generated by teleoperatio
Example:
```shell
-python -m lerobot.record \
+lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
- --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
+ --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--robot.id=black \
- --dataset.repo_id=aliberts/record-test \
+ --dataset.repo_id=/ \
--dataset.num_episodes=2 \
--dataset.single_task="Grab the cube" \
+ --display_data=true
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
# --teleop.type=so100_leader \
# --teleop.port=/dev/tty.usbmodem58760431551 \
@@ -33,14 +34,36 @@ python -m lerobot.record \
# <- Policy optional if you want to record with a policy \
# --policy.path=${HF_USER}/my_policy \
```
+
+Example recording with bimanual so100:
+```shell
+lerobot-record \
+ --robot.type=bi_so100_follower \
+ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
+ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
+ --robot.id=bimanual_follower \
+ --robot.cameras='{
+ left: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
+ top: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30},
+ right: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30}
+ }' \
+ --teleop.type=bi_so100_leader \
+ --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \
+ --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \
+ --teleop.id=bimanual_leader \
+ --display_data=true \
+ --dataset.repo_id=${HF_USER}/bimanual-so100-handover-cube \
+ --dataset.num_episodes=25 \
+ --dataset.single_task="Grab and handover the red cube to the other arm"
+```
"""
import logging
import time
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
-from typing import List
+from typing import Any
from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
@@ -51,12 +74,25 @@ from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
-from lerobot.policies.factory import make_policy
+from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
+from lerobot.datasets.video_utils import VideoEncodingManager
+from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.processor import (
+ PolicyAction,
+ PolicyProcessorPipeline,
+ RobotAction,
+ RobotObservation,
+ RobotProcessorPipeline,
+ make_default_processors,
+)
+from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
+ bi_so100_follower,
+ hope_jr,
koch_follower,
make_robot_from_config,
so100_follower,
@@ -65,12 +101,15 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
+ bi_so100_leader,
+ homunculus,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
+from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
is_headless,
@@ -78,13 +117,14 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
+from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
log_say,
)
-from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
@@ -120,6 +160,11 @@ class DatasetRecordConfig:
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
+ # Number of episodes to record before batch encoding videos
+ # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
+ video_encoding_batch_size: int = 1
+ # Rename map for the observation to override the image and state keys
+ rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
@@ -166,14 +211,55 @@ class RecordConfig:
return ["policy"]
+""" --------------- record_loop() data flow --------------------------
+ [ Robot ]
+ V
+ [ robot.get_observation() ] ---> raw_obs
+ V
+ [ robot_observation_processor ] ---> processed_obs
+ V
+ .-----( ACTION LOGIC )------------------.
+ V V
+ [ From Teleoperator ] [ From Policy ]
+ | |
+ | [teleop.get_action] -> raw_action | [predict_action]
+ | | | |
+ | V | V
+ | [teleop_action_processor] | |
+ | | | |
+ '---> processed_teleop_action '---> processed_policy_action
+ | |
+ '-------------------------.-------------'
+ V
+ [ robot_action_processor ] --> robot_action_to_send
+ V
+ [ robot.send_action() ] -- (Robot Executes)
+ V
+ ( Save to Dataset )
+ V
+ ( Rerun Log / Loop Wait )
+"""
+
+
@safe_stop_image_writer
def record_loop(
robot: Robot,
events: dict,
fps: int,
+ teleop_action_processor: RobotProcessorPipeline[
+ tuple[RobotAction, RobotObservation], RobotAction
+ ], # runs after teleop
+ robot_action_processor: RobotProcessorPipeline[
+ tuple[RobotAction, RobotObservation], RobotAction
+ ], # runs before robot
+ robot_observation_processor: RobotProcessorPipeline[
+ RobotObservation, RobotObservation
+ ], # runs after robot
dataset: LeRobotDataset | None = None,
- teleop: Teleoperator | List[Teleoperator] | None = None,
+ teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
+ preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
+ postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
@@ -188,7 +274,10 @@ def record_loop(
(
t
for t in teleop
- if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
+ if isinstance(
+ t,
+ (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
+ )
),
None,
)
@@ -198,9 +287,11 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
)
- # if policy is given it needs cleaning up
- if policy is not None:
+ # Reset policy and processor if they are provided
+ if policy is not None and preprocessor is not None and postprocessor is not None:
policy.reset()
+ preprocessor.reset()
+ postprocessor.reset()
timestamp = 0
start_episode_t = time.perf_counter()
@@ -211,32 +302,46 @@ def record_loop(
events["exit_early"] = False
break
- observation = robot.get_observation()
+ # Get robot observation
+ obs = robot.get_observation()
+
+ # Applies a pipeline to the raw robot observation, default is IdentityProcessor
+ obs_processed = robot_observation_processor(obs)
if policy is not None or dataset is not None:
- observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
+ observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
- if policy is not None:
+ # Get action from either policy or teleop
+ if policy is not None and preprocessor is not None and postprocessor is not None:
action_values = predict_action(
- observation_frame,
- policy,
- get_safe_torch_device(policy.config.device),
- policy.config.use_amp,
+ observation=observation_frame,
+ policy=policy,
+ device=get_safe_torch_device(policy.config.device),
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
- action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
+
+ action_names = dataset.features[ACTION]["names"]
+ act_processed_policy: RobotAction = {
+ f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
+ }
+
elif policy is None and isinstance(teleop, Teleoperator):
- action = teleop.get_action()
+ act = teleop.get_action()
+
+ # Applies a pipeline to the raw teleop action, default is IdentityProcessor
+ act_processed_teleop = teleop_action_processor((act, obs))
+
elif policy is None and isinstance(teleop, list):
- # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
-
keyboard_action = teleop_keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_action)
-
- action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
+ act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
+ act_processed_teleop = teleop_action_processor((act, obs))
else:
logging.info(
"No policy or teleoperator provided, skipping action generation."
@@ -245,17 +350,28 @@ def record_loop(
)
continue
- # Action can eventually be clipped using `max_relative_target`,
- # so action actually sent is saved in the dataset.
- sent_action = robot.send_action(action)
+ # Applies a pipeline to the action, default is IdentityProcessor
+ if policy is not None and act_processed_policy is not None:
+ action_values = act_processed_policy
+ robot_action_to_send = robot_action_processor((act_processed_policy, obs))
+ else:
+ action_values = act_processed_teleop
+ robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
+ # Send action to robot
+ # Action can eventually be clipped using `max_relative_target`,
+ # so action actually sent is saved in the dataset. action = postprocessor.process(action)
+ # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
+ _sent_action = robot.send_action(robot_action_to_send)
+
+ # Write to dataset
if dataset is not None:
- action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
- frame = {**observation_frame, **action_frame}
- dataset.add_frame(frame, task=single_task)
+ action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
+ frame = {**observation_frame, **action_frame, "task": single_task}
+ dataset.add_frame(frame)
if display_data:
- log_rerun_data(observation, action)
+ log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
@@ -268,19 +384,33 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
- _init_rerun(session_name="recording")
+ init_rerun(session_name="recording")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
- action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
- obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
- dataset_features = {**action_features, **obs_features}
+ teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
+
+ dataset_features = combine_feature_dicts(
+ aggregate_pipeline_dataset_features(
+ pipeline=teleop_action_processor,
+ initial_features=create_initial_features(
+ action=robot.action_features
+ ), # TODO(steven, pepijn): in future this should be come from teleop or policy
+ use_videos=cfg.dataset.video,
+ ),
+ aggregate_pipeline_dataset_features(
+ pipeline=robot_observation_processor,
+ initial_features=create_initial_features(observation=robot.observation_features),
+ use_videos=cfg.dataset.video,
+ ),
+ )
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
+ batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
@@ -301,10 +431,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
+ batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load pretrained policy
-
if cfg.policy and cfg.policy.use_peft:
from peft import PeftModel
@@ -314,59 +444,83 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
cfg.policy.pretrained_path = None
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
-
policy = PeftModel.from_pretrained(policy, peft_path)
+
+ # it is not necessary to merge and unload but for methods that support merging,
+ # it brings inference performance benefits.
policy = policy.merge_and_unload()
else:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
+ preprocessor = None
+ postprocessor = None
+ if cfg.policy is not None:
+ preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
+ preprocessor_overrides={
+ "device_processor": {"device": cfg.policy.device},
+ "rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
+ },
+ )
+
robot.connect()
if teleop is not None:
teleop.connect()
listener, events = init_keyboard_listener()
- recorded_episodes = 0
- while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
- log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
- record_loop(
- robot=robot,
- events=events,
- fps=cfg.dataset.fps,
- teleop=teleop,
- policy=policy,
- dataset=dataset,
- control_time_s=cfg.dataset.episode_time_s,
- single_task=cfg.dataset.single_task,
- display_data=cfg.display_data,
- )
-
- # Execute a few seconds without recording to give time to manually reset the environment
- # Skip reset for the last episode to be recorded
- if not events["stop_recording"] and (
- (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
- ):
- log_say("Reset the environment", cfg.play_sounds)
+ with VideoEncodingManager(dataset):
+ recorded_episodes = 0
+ while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
+ log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
teleop=teleop,
- control_time_s=cfg.dataset.reset_time_s,
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ dataset=dataset,
+ control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
- if events["rerecord_episode"]:
- log_say("Re-record episode", cfg.play_sounds)
- events["rerecord_episode"] = False
- events["exit_early"] = False
- dataset.clear_episode_buffer()
- continue
+ # Execute a few seconds without recording to give time to manually reset the environment
+ # Skip reset for the last episode to be recorded
+ if not events["stop_recording"] and (
+ (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
+ ):
+ log_say("Reset the environment", cfg.play_sounds)
+ record_loop(
+ robot=robot,
+ events=events,
+ fps=cfg.dataset.fps,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
+ teleop=teleop,
+ control_time_s=cfg.dataset.reset_time_s,
+ single_task=cfg.dataset.single_task,
+ display_data=cfg.display_data,
+ )
- dataset.save_episode()
- recorded_episodes += 1
+ if events["rerecord_episode"]:
+ log_say("Re-record episode", cfg.play_sounds)
+ events["rerecord_episode"] = False
+ events["exit_early"] = False
+ dataset.clear_episode_buffer()
+ continue
+
+ dataset.save_episode()
+ recorded_episodes += 1
log_say("Stop recording", cfg.play_sounds, blocking=True)
@@ -384,5 +538,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
return dataset
-if __name__ == "__main__":
+def main():
+ register_third_party_devices()
record()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/replay.py b/src/lerobot/scripts/lerobot_replay.py
similarity index 64%
rename from src/lerobot/replay.py
rename to src/lerobot/scripts/lerobot_replay.py
index ef20c28ef..ffd7b2b22 100644
--- a/src/lerobot/replay.py
+++ b/src/lerobot/scripts/lerobot_replay.py
@@ -15,16 +15,28 @@
"""
Replays the actions of an episode from a dataset on a robot.
-Example:
+Examples:
```shell
-python -m lerobot.replay \
+lerobot-replay \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--dataset.repo_id=aliberts/record-test \
- --dataset.episode=2
+ --dataset.episode=0
```
+
+Example replay with bimanual so100:
+```shell
+lerobot-replay \
+ --robot.type=bi_so100_follower \
+ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
+ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
+ --robot.id=bimanual_follower \
+ --dataset.repo_id=${HF_USER}/bimanual-so100-handover-cube \
+ --dataset.episode=0
+```
+
"""
import logging
@@ -33,17 +45,23 @@ from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
-import draccus
-
+from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.processor import (
+ make_default_robot_action_processor,
+)
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
+ bi_so100_follower,
+ hope_jr,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
+from lerobot.utils.constants import ACTION
+from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -71,26 +89,36 @@ class ReplayConfig:
play_sounds: bool = True
-@draccus.wrap()
+@parser.wrap()
def replay(cfg: ReplayConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
+ robot_action_processor = make_default_robot_action_processor()
+
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
- actions = dataset.hf_dataset.select_columns("action")
+
+ # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
+ episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
+ actions = episode_frames.select_columns(ACTION)
+
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
- for idx in range(dataset.num_frames):
+ for idx in range(len(episode_frames)):
start_episode_t = time.perf_counter()
- action_array = actions[idx]["action"]
+ action_array = actions[idx][ACTION]
action = {}
- for i, name in enumerate(dataset.features["action"]["names"]):
+ for i, name in enumerate(dataset.features[ACTION]["names"]):
action[name] = action_array[i]
- robot.send_action(action)
+ robot_obs = robot.get_observation()
+
+ processed_action = robot_action_processor((action, robot_obs))
+
+ _ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)
@@ -98,5 +126,10 @@ def replay(cfg: ReplayConfig):
robot.disconnect()
-if __name__ == "__main__":
+def main():
+ register_third_party_devices()
replay()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py
similarity index 97%
rename from src/lerobot/setup_motors.py
rename to src/lerobot/scripts/lerobot_setup_motors.py
index c54582a1d..c1d256c21 100644
--- a/src/lerobot/setup_motors.py
+++ b/src/lerobot/scripts/lerobot_setup_motors.py
@@ -18,7 +18,7 @@ Helper to set motor ids and baudrate.
Example:
```shell
-python -m lerobot.setup_motors \
+lerobot-setup-motors \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem575E0031751
```
@@ -80,5 +80,9 @@ def setup_motors(cfg: SetupConfig):
device.setup_motors()
-if __name__ == "__main__":
+def main():
setup_motors()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py
new file mode 100644
index 000000000..0a418f3bc
--- /dev/null
+++ b/src/lerobot/scripts/lerobot_teleoperate.py
@@ -0,0 +1,224 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Simple script to control a robot from teleoperation.
+
+Example:
+
+```shell
+lerobot-teleoperate \
+ --robot.type=so101_follower \
+ --robot.port=/dev/tty.usbmodem58760431541 \
+ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
+ --robot.id=black \
+ --teleop.type=so101_leader \
+ --teleop.port=/dev/tty.usbmodem58760431551 \
+ --teleop.id=blue \
+ --display_data=true
+```
+
+Example teleoperation with bimanual so100:
+
+```shell
+lerobot-teleoperate \
+ --robot.type=bi_so100_follower \
+ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \
+ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \
+ --robot.id=bimanual_follower \
+ --robot.cameras='{
+ left: {"type": "opencv", "index_or_path": 0, "width": 1920, "height": 1080, "fps": 30},
+ top: {"type": "opencv", "index_or_path": 1, "width": 1920, "height": 1080, "fps": 30},
+ right: {"type": "opencv", "index_or_path": 2, "width": 1920, "height": 1080, "fps": 30}
+ }' \
+ --teleop.type=bi_so100_leader \
+ --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \
+ --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \
+ --teleop.id=bimanual_leader \
+ --display_data=true
+```
+
+"""
+
+import logging
+import time
+from dataclasses import asdict, dataclass
+from pprint import pformat
+
+import rerun as rr
+
+from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.configs import parser
+from lerobot.processor import (
+ RobotAction,
+ RobotObservation,
+ RobotProcessorPipeline,
+ make_default_processors,
+)
+from lerobot.robots import ( # noqa: F401
+ Robot,
+ RobotConfig,
+ bi_so100_follower,
+ hope_jr,
+ koch_follower,
+ make_robot_from_config,
+ so100_follower,
+ so101_follower,
+)
+from lerobot.teleoperators import ( # noqa: F401
+ Teleoperator,
+ TeleoperatorConfig,
+ bi_so100_leader,
+ gamepad,
+ homunculus,
+ koch_leader,
+ make_teleoperator_from_config,
+ so100_leader,
+ so101_leader,
+)
+from lerobot.utils.import_utils import register_third_party_devices
+from lerobot.utils.robot_utils import busy_wait
+from lerobot.utils.utils import init_logging, move_cursor_up
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+
+
+@dataclass
+class TeleoperateConfig:
+ # TODO: pepijn, steven: if more robots require multiple teleoperators (like lekiwi) its good to make this possibele in teleop.py and record.py with List[Teleoperator]
+ teleop: TeleoperatorConfig
+ robot: RobotConfig
+ # Limit the maximum frames per second.
+ fps: int = 60
+ teleop_time_s: float | None = None
+ # Display all cameras on screen
+ display_data: bool = False
+
+
+def teleop_loop(
+ teleop: Teleoperator,
+ robot: Robot,
+ fps: int,
+ teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
+ robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
+ robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
+ display_data: bool = False,
+ duration: float | None = None,
+):
+ """
+ This function continuously reads actions from a teleoperation device, processes them through optional
+ pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
+ specified frequency until a set duration is reached or it is manually interrupted.
+
+ Args:
+ teleop: The teleoperator device instance providing control actions.
+ robot: The robot instance being controlled.
+ fps: The target frequency for the control loop in frames per second.
+ display_data: If True, fetches robot observations and displays them in the console and Rerun.
+ duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
+ teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
+ robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
+ robot_observation_processor: An optional pipeline to process raw observations from the robot.
+ """
+
+ display_len = max(len(key) for key in robot.action_features)
+ start = time.perf_counter()
+
+ while True:
+ loop_start = time.perf_counter()
+
+ # Get robot observation
+ # Not really needed for now other than for visualization
+ # teleop_action_processor can take None as an observation
+ # given that it is the identity processor as default
+ obs = robot.get_observation()
+
+ # Get teleop action
+ raw_action = teleop.get_action()
+
+ # Process teleop action through pipeline
+ teleop_action = teleop_action_processor((raw_action, obs))
+
+ # Process action for robot through pipeline
+ robot_action_to_send = robot_action_processor((teleop_action, obs))
+
+ # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
+ _ = robot.send_action(robot_action_to_send)
+
+ if display_data:
+ # Process robot observation through pipeline
+ obs_transition = robot_observation_processor(obs)
+
+ log_rerun_data(
+ observation=obs_transition,
+ action=teleop_action,
+ )
+
+ print("\n" + "-" * (display_len + 10))
+ print(f"{'NAME':<{display_len}} | {'NORM':>7}")
+ # Display the final robot action that was sent
+ for motor, value in robot_action_to_send.items():
+ print(f"{motor:<{display_len}} | {value:>7.2f}")
+ move_cursor_up(len(robot_action_to_send) + 5)
+
+ dt_s = time.perf_counter() - loop_start
+ busy_wait(1 / fps - dt_s)
+ loop_s = time.perf_counter() - loop_start
+ print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
+
+ if duration is not None and time.perf_counter() - start >= duration:
+ return
+
+
+@parser.wrap()
+def teleoperate(cfg: TeleoperateConfig):
+ init_logging()
+ logging.info(pformat(asdict(cfg)))
+ if cfg.display_data:
+ init_rerun(session_name="teleoperation")
+
+ teleop = make_teleoperator_from_config(cfg.teleop)
+ robot = make_robot_from_config(cfg.robot)
+ teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
+
+ teleop.connect()
+ robot.connect()
+
+ try:
+ teleop_loop(
+ teleop=teleop,
+ robot=robot,
+ fps=cfg.fps,
+ display_data=cfg.display_data,
+ duration=cfg.teleop_time_s,
+ teleop_action_processor=teleop_action_processor,
+ robot_action_processor=robot_action_processor,
+ robot_observation_processor=robot_observation_processor,
+ )
+ except KeyboardInterrupt:
+ pass
+ finally:
+ if cfg.display_data:
+ rr.rerun_shutdown()
+ teleop.disconnect()
+ robot.disconnect()
+
+
+def main():
+ register_third_party_devices()
+ teleoperate()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/lerobot_train.py
similarity index 72%
rename from src/lerobot/scripts/train.py
rename to src/lerobot/scripts/lerobot_train.py
index e724627ac..c608c855f 100644
--- a/src/lerobot/scripts/train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -31,11 +31,13 @@ from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
+from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
-from lerobot.policies.factory import make_policy
+from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
-from lerobot.scripts.eval import eval_policy
+from lerobot.rl.wandb_utils import WandBLogger
+from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
@@ -51,7 +53,6 @@ from lerobot.utils.utils import (
has_method,
init_logging,
)
-from lerobot.utils.wandb_utils import WandBLogger
def update_policy(
@@ -65,6 +66,28 @@ def update_policy(
use_amp: bool = False,
lock=None,
) -> tuple[MetricsTracker, dict]:
+ """
+ Performs a single training step to update the policy's weights.
+
+ This function executes the forward and backward passes, clips gradients, and steps the optimizer and
+ learning rate scheduler. It also handles mixed-precision training via a GradScaler.
+
+ Args:
+ train_metrics: A MetricsTracker instance to record training statistics.
+ policy: The policy model to be trained.
+ batch: A batch of training data.
+ optimizer: The optimizer used to update the policy's parameters.
+ grad_clip_norm: The maximum norm for gradient clipping.
+ grad_scaler: The GradScaler for automatic mixed-precision training.
+ lr_scheduler: An optional learning rate scheduler.
+ use_amp: A boolean indicating whether to use automatic mixed precision.
+ lock: An optional lock for thread-safe optimizer updates.
+
+ Returns:
+ A tuple containing:
+ - The updated MetricsTracker with new statistics for this step.
+ - A dictionary of outputs from the policy's forward pass, for logging purposes.
+ """
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
@@ -176,6 +199,20 @@ def wrap_policy_in_peft_model(cfg, policy):
@parser.wrap()
def train(cfg: TrainPipelineConfig):
+ """
+ Main function to train a policy.
+
+ This function orchestrates the entire training pipeline, including:
+ - Setting up logging, seeding, and device configuration.
+ - Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
+ - Handling resumption from a checkpoint.
+ - Running the main training loop, which involves fetching data batches and calling `update_policy`.
+ - Periodically logging metrics, saving model checkpoints, and evaluating the policy.
+ - Pushing the final trained model to the Hugging Face Hub if configured.
+
+ Args:
+ cfg: A `TrainPipelineConfig` object containing all training configurations.
+ """
cfg.validate()
logging.info(pformat(cfg.to_dict()))
@@ -214,6 +251,37 @@ def train(cfg: TrainPipelineConfig):
logging.info("Using PEFT! Wrapping model.")
policy = wrap_policy_in_peft_model(cfg, policy)
+ # Create processors - only provide dataset_stats if not resuming from saved processors
+ processor_kwargs = {}
+ postprocessor_kwargs = {}
+ if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
+ # Only provide dataset_stats when not resuming from saved processor state
+ processor_kwargs["dataset_stats"] = dataset.meta.stats
+
+ if cfg.policy.pretrained_path is not None:
+ processor_kwargs["preprocessor_overrides"] = {
+ "device_processor": {"device": device.type},
+ "normalizer_processor": {
+ "stats": dataset.meta.stats,
+ "features": {**policy.config.input_features, **policy.config.output_features},
+ "norm_map": policy.config.normalization_mapping,
+ },
+ }
+ postprocessor_kwargs["postprocessor_overrides"] = {
+ "unnormalizer_processor": {
+ "stats": dataset.meta.stats,
+ "features": policy.config.output_features,
+ "norm_map": policy.config.normalization_mapping,
+ },
+ }
+
+ preprocessor, postprocessor = make_pre_post_processors(
+ policy_cfg=cfg.policy,
+ pretrained_path=cfg.policy.pretrained_path,
+ **processor_kwargs,
+ **postprocessor_kwargs,
+ )
+
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -239,7 +307,8 @@ def train(cfg: TrainPipelineConfig):
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
- dataset.episode_data_index,
+ dataset.meta.episodes["dataset_from_index"],
+ dataset.meta.episodes["dataset_to_index"],
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
@@ -251,10 +320,11 @@ def train(cfg: TrainPipelineConfig):
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
- shuffle=shuffle,
+ shuffle=shuffle and not cfg.dataset.streaming,
sampler=sampler,
- pin_memory=device.type != "cpu",
+ pin_memory=device.type == "cuda",
drop_last=False,
+ prefetch_factor=2,
)
dl_iter = cycle(dataloader)
@@ -276,12 +346,9 @@ def train(cfg: TrainPipelineConfig):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
+ batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
- for key in batch:
- if isinstance(batch[key], torch.Tensor):
- batch[key] = batch[key].to(device, non_blocking=True)
-
train_tracker, output_dict = update_policy(
train_tracker,
policy,
@@ -313,7 +380,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
- save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+ save_checkpoint(
+ checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
+ )
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -325,15 +394,25 @@ def train(cfg: TrainPipelineConfig):
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
- eval_info = eval_policy(
- eval_env,
- policy,
- cfg.eval.n_episodes,
+ eval_info = eval_policy_all(
+ envs=eval_env, # dict[suite][task_id] -> vec_env
+ policy=policy,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
+ max_parallel_tasks=cfg.env.max_parallel_tasks,
)
+ # overall metrics (suite-agnostic)
+ aggregated = eval_info["overall"]
+ # optional: per-suite logging
+ for suite, suite_info in eval_info.items():
+ logging.info("Suite %s aggregated: %s", suite, suite_info)
+
+ # meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
@@ -342,23 +421,28 @@ def train(cfg: TrainPipelineConfig):
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
- eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
- eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
- eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
- logging.info(eval_tracker)
+ eval_tracker.eval_s = aggregated.pop("eval_s")
+ eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
+ eval_tracker.pc_success = aggregated.pop("pc_success")
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
- wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
+ wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
if eval_env:
- eval_env.close()
+ close_envs(eval_env)
logging.info("End of training")
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
+ preprocessor.push_to_hub(cfg.policy.repo_id)
+ postprocessor.push_to_hub(cfg.policy.repo_id)
+
+
+def main():
+ init_logging()
+ train()
if __name__ == "__main__":
- init_logging()
- train()
+ main()
diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py
deleted file mode 100644
index 673043b6e..000000000
--- a/src/lerobot/scripts/rl/gym_manipulator.py
+++ /dev/null
@@ -1,2262 +0,0 @@
-# !/usr/bin/env python
-
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-Robot Environment for LeRobot Manipulation Tasks
-
-This module provides a comprehensive gym-compatible environment for robot manipulation
-with support for:
-- Multiple robot types (SO100, SO101, Koch and Moss)
-- Human intervention via leader-follower control or gamepad
-
-- End-effector and joint space control
-- Image processing (cropping and resizing)
-
-The environment is built using a composable wrapper pattern where each wrapper
-adds specific functionality to the base RobotEnv.
-
-Example:
- env = make_robot_env(cfg)
- obs, info = env.reset()
- action = policy.select_action(obs)
- obs, reward, terminated, truncated, info = env.step(action)
-"""
-
-import logging
-import time
-from collections import deque
-from threading import Lock
-from typing import Annotated, Any, Sequence
-
-import gymnasium as gym
-import numpy as np
-import torch
-import torchvision.transforms.functional as F # noqa: N812
-
-from lerobot.cameras import opencv # noqa: F401
-from lerobot.configs import parser
-from lerobot.envs.configs import EnvConfig
-from lerobot.envs.utils import preprocess_observation
-from lerobot.model.kinematics import RobotKinematics
-from lerobot.robots import ( # noqa: F401
- RobotConfig,
- make_robot_from_config,
- so100_follower,
-)
-from lerobot.teleoperators import (
- gamepad, # noqa: F401
- keyboard, # noqa: F401
- make_teleoperator_from_config,
- so101_leader, # noqa: F401
-)
-from lerobot.teleoperators.gamepad.teleop_gamepad import GamepadTeleop
-from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardEndEffectorTeleop
-from lerobot.utils.robot_utils import busy_wait
-from lerobot.utils.utils import log_say
-
-logging.basicConfig(level=logging.INFO)
-
-
-def reset_follower_position(robot_arm, target_position):
- current_position_dict = robot_arm.bus.sync_read("Present_Position")
- current_position = np.array(
- [current_position_dict[name] for name in current_position_dict], dtype=np.float32
- )
- trajectory = torch.from_numpy(
- np.linspace(current_position, target_position, 50)
- ) # NOTE: 30 is just an arbitrary number
- for pose in trajectory:
- action_dict = dict(zip(current_position_dict, pose, strict=False))
- robot_arm.bus.sync_write("Goal_Position", action_dict)
- busy_wait(0.015)
-
-
-class TorchBox(gym.spaces.Box):
- """
- A version of gym.spaces.Box that handles PyTorch tensors.
-
- This class extends gym.spaces.Box to work with PyTorch tensors,
- providing compatibility between NumPy arrays and PyTorch tensors.
- """
-
- def __init__(
- self,
- low: float | Sequence[float] | np.ndarray,
- high: float | Sequence[float] | np.ndarray,
- shape: Sequence[int] | None = None,
- np_dtype: np.dtype | type = np.float32,
- torch_dtype: torch.dtype = torch.float32,
- device: str = "cpu",
- seed: int | np.random.Generator | None = None,
- ) -> None:
- """
- Initialize the PyTorch-compatible Box space.
-
- Args:
- low: Lower bounds of the space.
- high: Upper bounds of the space.
- shape: Shape of the space. If None, inferred from low and high.
- np_dtype: NumPy data type for internal storage.
- torch_dtype: PyTorch data type for tensor conversion.
- device: PyTorch device for returned tensors.
- seed: Random seed for sampling.
- """
- super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed)
- self.torch_dtype = torch_dtype
- self.device = device
-
- def sample(self) -> torch.Tensor:
- """
- Sample a random point from the space.
-
- Returns:
- A PyTorch tensor within the space bounds.
- """
- arr = super().sample()
- return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device)
-
- def contains(self, x: torch.Tensor) -> bool:
- """
- Check if a tensor is within the space bounds.
-
- Args:
- x: The PyTorch tensor to check.
-
- Returns:
- Boolean indicating whether the tensor is within bounds.
- """
- # Move to CPU/numpy and cast to the internal dtype
- arr = x.detach().cpu().numpy().astype(self.dtype, copy=False)
- return super().contains(arr)
-
- def seed(self, seed: int | np.random.Generator | None = None):
- """
- Set the random seed for sampling.
-
- Args:
- seed: The random seed to use.
-
- Returns:
- List containing the seed.
- """
- super().seed(seed)
- return [seed]
-
- def __repr__(self) -> str:
- """
- Return a string representation of the space.
-
- Returns:
- Formatted string with space details.
- """
- return (
- f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, "
- f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})"
- )
-
-
-class TorchActionWrapper(gym.Wrapper):
- """
- Wrapper that changes the action space to use PyTorch tensors.
-
- This wrapper modifies the action space to return PyTorch tensors when sampled
- and handles converting PyTorch actions to NumPy when stepping the environment.
- """
-
- def __init__(self, env: gym.Env, device: str):
- """
- Initialize the PyTorch action space wrapper.
-
- Args:
- env: The environment to wrap.
- device: The PyTorch device to use for tensor operations.
- """
- super().__init__(env)
- self.action_space = TorchBox(
- low=env.action_space.low,
- high=env.action_space.high,
- shape=env.action_space.shape,
- torch_dtype=torch.float32,
- device=torch.device("cpu"),
- )
-
- def step(self, action: torch.Tensor):
- """
- Step the environment with a PyTorch tensor action.
-
- This method handles conversion from PyTorch tensors to NumPy arrays
- for compatibility with the underlying environment.
-
- Args:
- action: PyTorch tensor action to take.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info).
- """
- if action.dim() == 2:
- action = action.squeeze(0)
- action = action.detach().cpu().numpy()
- return self.env.step(action)
-
-
-class RobotEnv(gym.Env):
- """
- Gym-compatible environment for evaluating robotic control policies with integrated human intervention.
-
- This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
- and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
- sensors and configuration.
- """
-
- def __init__(
- self,
- robot,
- use_gripper: bool = False,
- display_cameras: bool = False,
- ):
- """
- Initialize the RobotEnv environment.
-
- The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
- supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
-
- Args:
- robot: The robot interface object used to connect and interact with the physical robot.
- display_cameras: If True, the robot's camera feeds will be displayed during execution.
- """
- super().__init__()
-
- self.robot = robot
- self.display_cameras = display_cameras
-
- # Connect to the robot if not already connected.
- if not self.robot.is_connected:
- self.robot.connect()
-
- # Episode tracking.
- self.current_step = 0
- self.episode_data = None
-
- self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
- self._image_keys = self.robot.cameras.keys()
-
- self.current_observation = None
-
- self.use_gripper = use_gripper
-
- self._setup_spaces()
-
- def _get_observation(self) -> dict[str, np.ndarray]:
- """Helper to convert a dictionary from bus.sync_read to an ordered numpy array."""
- obs_dict = self.robot.get_observation()
- joint_positions = np.array([obs_dict[name] for name in self._joint_names])
-
- images = {key: obs_dict[key] for key in self._image_keys}
- self.current_observation = {"agent_pos": joint_positions, "pixels": images}
-
- def _setup_spaces(self):
- """
- Dynamically configure the observation and action spaces based on the robot's capabilities.
-
- Observation Space:
- - For keys with "image": A Box space with pixel values ranging from 0 to 255.
- - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.
-
- Action Space:
- - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta)
- or absolute, based on the configuration.
- """
- self._get_observation()
-
- observation_spaces = {}
-
- # Define observation spaces for images and other states.
- if "pixels" in self.current_observation:
- prefix = "observation.images"
- observation_spaces = {
- f"{prefix}.{key}": gym.spaces.Box(
- low=0, high=255, shape=self.current_observation["pixels"][key].shape, dtype=np.uint8
- )
- for key in self.current_observation["pixels"]
- }
-
- observation_spaces["observation.state"] = gym.spaces.Box(
- low=0,
- high=10,
- shape=self.current_observation["agent_pos"].shape,
- dtype=np.float32,
- )
-
- self.observation_space = gym.spaces.Dict(observation_spaces)
-
- # Define the action space for joint positions along with setting an intervention flag.
- action_dim = 3
- bounds = {}
- bounds["min"] = -np.ones(action_dim)
- bounds["max"] = np.ones(action_dim)
-
- if self.use_gripper:
- action_dim += 1
- bounds["min"] = np.concatenate([bounds["min"], [0]])
- bounds["max"] = np.concatenate([bounds["max"], [2]])
-
- self.action_space = gym.spaces.Box(
- low=bounds["min"],
- high=bounds["max"],
- shape=(action_dim,),
- dtype=np.float32,
- )
-
- def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
- """
- Reset the environment to its initial state.
- This method resets the step counter and clears any episodic data.
-
- Args:
- seed: A seed for random number generation to ensure reproducibility.
- options: Additional options to influence the reset behavior.
-
- Returns:
- A tuple containing:
- - observation (dict): The initial sensor observation.
- - info (dict): A dictionary with supplementary information, including the key "is_intervention".
- """
- super().reset(seed=seed, options=options)
-
- self.robot.reset()
-
- # Reset episode tracking variables.
- self.current_step = 0
- self.episode_data = None
- self.current_observation = None
- self._get_observation()
- return self.current_observation, {"is_intervention": False}
-
- def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
- """
- Execute a single step within the environment using the specified action.
-
- The provided action is processed and sent to the robot as joint position commands
- that may be either absolute values or deltas based on the environment configuration.
-
- Args:
- action: The commanded joint positions as a numpy array or torch tensor.
-
- Returns:
- A tuple containing:
- - observation (dict): The new sensor observation after taking the step.
- - reward (float): The step reward (default is 0.0 within this wrapper).
- - terminated (bool): True if the episode has reached a terminal state.
- - truncated (bool): True if the episode was truncated (e.g., time constraints).
- - info (dict): Additional debugging information including intervention status.
- """
- action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]}
-
- # 1.0 action corresponds to no-op action
- action_dict["gripper"] = action[3] if self.use_gripper else 1.0
-
- self.robot.send_action(action_dict)
-
- self._get_observation()
-
- if self.display_cameras:
- self.render()
-
- self.current_step += 1
-
- reward = 0.0
- terminated = False
- truncated = False
-
- return (
- self.current_observation,
- reward,
- terminated,
- truncated,
- {"is_intervention": False},
- )
-
- def render(self):
- """
- Render the current state of the environment by displaying the robot's camera feeds.
- """
- import cv2
-
- image_keys = [key for key in self.current_observation if "image" in key]
-
- for key in image_keys:
- cv2.imshow(key, cv2.cvtColor(self.current_observation[key].numpy(), cv2.COLOR_RGB2BGR))
- cv2.waitKey(1)
-
- def close(self):
- """
- Close the environment and clean up resources by disconnecting the robot.
-
- If the robot is currently connected, this method properly terminates the connection to ensure that all
- associated resources are released.
- """
- if self.robot.is_connected:
- self.robot.disconnect()
-
-
-class AddJointVelocityToObservation(gym.ObservationWrapper):
- """
- Wrapper that adds joint velocity information to the observation.
-
- This wrapper computes joint velocities by tracking changes in joint positions over time,
- and extends the observation space to include these velocities.
- """
-
- def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
- """
- Initialize the joint velocity wrapper.
-
- Args:
- env: The environment to wrap.
- joint_velocity_limits: Maximum expected joint velocity for space bounds.
- fps: Frames per second used to calculate velocity (position delta / time).
- num_dof: Number of degrees of freedom (joints) in the robot.
- """
- super().__init__(env)
-
- # Extend observation space to include joint velocities
- old_low = self.observation_space["observation.state"].low
- old_high = self.observation_space["observation.state"].high
- old_shape = self.observation_space["observation.state"].shape
-
- self.last_joint_positions = np.zeros(num_dof)
-
- new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
- new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])
-
- new_shape = (old_shape[0] + num_dof,)
-
- self.observation_space["observation.state"] = gym.spaces.Box(
- low=new_low,
- high=new_high,
- shape=new_shape,
- dtype=np.float32,
- )
-
- self.dt = 1.0 / fps
-
- def observation(self, observation):
- """
- Add joint velocity information to the observation.
-
- Args:
- observation: The original observation from the environment.
-
- Returns:
- The modified observation with joint velocities.
- """
- joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt
- self.last_joint_positions = observation["agent_pos"]
- observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1)
- return observation
-
-
-class AddCurrentToObservation(gym.ObservationWrapper):
- """
- Wrapper that adds motor current information to the observation.
-
- This wrapper extends the observation space to include the current values
- from each motor, providing information about the forces being applied.
- """
-
- def __init__(self, env, max_current=500, num_dof=6):
- """
- Initialize the current observation wrapper.
-
- Args:
- env: The environment to wrap.
- max_current: Maximum expected current for space bounds.
- num_dof: Number of degrees of freedom (joints) in the robot.
- """
- super().__init__(env)
-
- # Extend observation space to include joint velocities
- old_low = self.observation_space["observation.state"].low
- old_high = self.observation_space["observation.state"].high
- old_shape = self.observation_space["observation.state"].shape
-
- new_low = np.concatenate([old_low, np.zeros(num_dof)])
- new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
-
- new_shape = (old_shape[0] + num_dof,)
-
- self.observation_space["observation.state"] = gym.spaces.Box(
- low=new_low,
- high=new_high,
- shape=new_shape,
- dtype=np.float32,
- )
-
- def observation(self, observation):
- """
- Add current information to the observation.
-
- Args:
- observation: The original observation from the environment.
-
- Returns:
- The modified observation with current values.
- """
- present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
- present_current_observation = np.array(
- [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors]
- )
- observation["agent_pos"] = np.concatenate(
- [observation["agent_pos"], present_current_observation], axis=-1
- )
- return observation
-
-
-class RewardWrapper(gym.Wrapper):
- def __init__(self, env, reward_classifier, device="cuda"):
- """
- Wrapper to add reward prediction to the environment using a trained classifier.
-
- Args:
- env: The environment to wrap.
- reward_classifier: The reward classifier model.
- device: The device to run the model on.
- """
- self.env = env
-
- self.device = device
-
- self.reward_classifier = torch.compile(reward_classifier)
- self.reward_classifier.to(self.device)
-
- def step(self, action):
- """
- Execute a step and compute the reward using the classifier.
-
- Args:
- action: The action to take in the environment.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info).
- """
- observation, _, terminated, truncated, info = self.env.step(action)
-
- images = {}
- for key in observation:
- if "image" in key:
- images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda"))
- if images[key].dim() == 3:
- images[key] = images[key].unsqueeze(0)
-
- start_time = time.perf_counter()
- with torch.inference_mode():
- success = (
- self.reward_classifier.predict_reward(images, threshold=0.7)
- if self.reward_classifier is not None
- else 0.0
- )
- info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)
-
- reward = 0.0
- if success == 1.0:
- terminated = True
- reward = 1.0
-
- return observation, reward, terminated, truncated, info
-
- def reset(self, seed=None, options=None):
- """
- Reset the environment.
-
- Args:
- seed: Random seed for reproducibility.
- options: Additional reset options.
-
- Returns:
- The initial observation and info from the wrapped environment.
- """
- return self.env.reset(seed=seed, options=options)
-
-
-class TimeLimitWrapper(gym.Wrapper):
- """
- Wrapper that adds a time limit to episodes and tracks execution time.
-
- This wrapper terminates episodes after a specified time has elapsed, providing
- better control over episode length.
- """
-
- def __init__(self, env, control_time_s, fps):
- """
- Initialize the time limit wrapper.
-
- Args:
- env: The environment to wrap.
- control_time_s: Maximum episode duration in seconds.
- fps: Frames per second for calculating the maximum number of steps.
- """
- self.env = env
- self.control_time_s = control_time_s
- self.fps = fps
-
- self.last_timestamp = 0.0
- self.episode_time_in_s = 0.0
-
- self.max_episode_steps = int(self.control_time_s * self.fps)
-
- self.current_step = 0
-
- def step(self, action):
- """
- Step the environment and track time elapsed.
-
- Args:
- action: The action to take in the environment.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info).
- """
- obs, reward, terminated, truncated, info = self.env.step(action)
- time_since_last_step = time.perf_counter() - self.last_timestamp
- self.episode_time_in_s += time_since_last_step
- self.last_timestamp = time.perf_counter()
- self.current_step += 1
- # check if last timestep took more time than the expected fps
- if 1.0 / time_since_last_step < self.fps:
- logging.debug(f"Current timestep exceeded expected fps {self.fps}")
-
- if self.current_step >= self.max_episode_steps:
- terminated = True
- return obs, reward, terminated, truncated, info
-
- def reset(self, seed=None, options=None):
- """
- Reset the environment and time tracking.
-
- Args:
- seed: Random seed for reproducibility.
- options: Additional reset options.
-
- Returns:
- The initial observation and info from the wrapped environment.
- """
- self.episode_time_in_s = 0.0
- self.last_timestamp = time.perf_counter()
- self.current_step = 0
- return self.env.reset(seed=seed, options=options)
-
-
-class ImageCropResizeWrapper(gym.Wrapper):
- """
- Wrapper that crops and resizes image observations.
-
- This wrapper processes image observations to focus on relevant regions by
- cropping and then resizing to a standard size.
- """
-
- def __init__(
- self,
- env,
- crop_params_dict: dict[str, Annotated[tuple[int], 4]],
- resize_size=None,
- ):
- """
- Initialize the image crop and resize wrapper.
-
- Args:
- env: The environment to wrap.
- crop_params_dict: Dictionary mapping image observation keys to crop parameters
- (top, left, height, width).
- resize_size: Target size for resized images (height, width). Defaults to (128, 128).
- """
- super().__init__(env)
- self.env = env
- self.crop_params_dict = crop_params_dict
- print(f"obs_keys , {self.env.observation_space}")
- print(f"crop params dict {crop_params_dict.keys()}")
- for key_crop in crop_params_dict:
- if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
- raise ValueError(f"Key {key_crop} not in observation space")
- for key in crop_params_dict:
- new_shape = (3, resize_size[0], resize_size[1])
- self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
-
- self.resize_size = resize_size
- if self.resize_size is None:
- self.resize_size = (128, 128)
-
- def step(self, action):
- """
- Step the environment and process image observations.
-
- Args:
- action: The action to take in the environment.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info) with processed images.
- """
- obs, reward, terminated, truncated, info = self.env.step(action)
- for k in self.crop_params_dict:
- device = obs[k].device
- if obs[k].dim() >= 3:
- # Reshape to combine height and width dimensions for easier calculation
- batch_size = obs[k].size(0)
- channels = obs[k].size(1)
- flattened_spatial_dims = obs[k].view(batch_size, channels, -1)
-
- # Calculate standard deviation across spatial dimensions (H, W)
- # If any channel has std=0, all pixels in that channel have the same value
- # This is helpful if one camera mistakenly covered or the image is black
- std_per_channel = torch.std(flattened_spatial_dims, dim=2)
- if (std_per_channel <= 0.02).any():
- logging.warning(
- f"Potential hardware issue detected: All pixels have the same value in observation {k}"
- )
-
- if device == torch.device("mps:0"):
- obs[k] = obs[k].cpu()
-
- obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
- obs[k] = F.resize(obs[k], self.resize_size)
- # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1]
- obs[k] = obs[k].clamp(0.0, 1.0)
- obs[k] = obs[k].to(device)
-
- return obs, reward, terminated, truncated, info
-
- def reset(self, seed=None, options=None):
- """
- Reset the environment and process image observations.
-
- Args:
- seed: Random seed for reproducibility.
- options: Additional reset options.
-
- Returns:
- Tuple of (observation, info) with processed images.
- """
- obs, info = self.env.reset(seed=seed, options=options)
- for k in self.crop_params_dict:
- device = obs[k].device
- if device == torch.device("mps:0"):
- obs[k] = obs[k].cpu()
- obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
- obs[k] = F.resize(obs[k], self.resize_size)
- obs[k] = obs[k].clamp(0.0, 1.0)
- obs[k] = obs[k].to(device)
- return obs, info
-
-
-class ConvertToLeRobotObservation(gym.ObservationWrapper):
- """
- Wrapper that converts standard observations to LeRobot format.
-
- This wrapper processes observations to match the expected format for LeRobot,
- including normalizing image values and moving tensors to the specified device.
- """
-
- def __init__(self, env, device: str = "cpu"):
- """
- Initialize the LeRobot observation converter.
-
- Args:
- env: The environment to wrap.
- device: Target device for the observation tensors.
- """
- super().__init__(env)
-
- self.device = torch.device(device)
-
- def observation(self, observation):
- """
- Convert observations to LeRobot format.
-
- Args:
- observation: The original observation from the environment.
-
- Returns:
- The processed observation with normalized images and proper tensor formats.
- """
- observation = preprocess_observation(observation)
- observation = {
- key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
- for key in observation
- }
- return observation
-
-
-class ResetWrapper(gym.Wrapper):
- """
- Wrapper that handles environment reset procedures.
-
- This wrapper provides additional functionality during environment reset,
- including the option to reset to a fixed pose or allow manual reset.
- """
-
- def __init__(
- self,
- env: RobotEnv,
- reset_pose: np.ndarray | None = None,
- reset_time_s: float = 5,
- ):
- """
- Initialize the reset wrapper.
-
- Args:
- env: The environment to wrap.
- reset_pose: Fixed joint positions to reset to. If None, manual reset is used.
- reset_time_s: Time in seconds to wait after reset or allowed for manual reset.
- """
- super().__init__(env)
- self.reset_time_s = reset_time_s
- self.reset_pose = reset_pose
- self.robot = self.unwrapped.robot
-
- def reset(self, *, seed=None, options=None):
- """
- Reset the environment with either fixed or manual reset procedure.
-
- If reset_pose is provided, the robot will move to that position.
- Otherwise, manual teleoperation control is allowed for reset_time_s seconds.
-
- Args:
- seed: Random seed for reproducibility.
- options: Additional reset options.
-
- Returns:
- The initial observation and info from the wrapped environment.
- """
- start_time = time.perf_counter()
- if self.reset_pose is not None:
- log_say("Reset the environment.", play_sounds=True)
- reset_follower_position(self.unwrapped.robot, self.reset_pose)
- log_say("Reset the environment done.", play_sounds=True)
-
- if hasattr(self.env, "robot_leader"):
- self.env.robot_leader.bus.sync_write("Torque_Enable", 1)
- log_say("Reset the leader robot.", play_sounds=True)
- reset_follower_position(self.env.robot_leader, self.reset_pose)
- log_say("Reset the leader robot done.", play_sounds=True)
- else:
- log_say(
- f"Manually reset the environment for {self.reset_time_s} seconds.",
- play_sounds=True,
- )
- start_time = time.perf_counter()
- while time.perf_counter() - start_time < self.reset_time_s:
- action = self.env.robot_leader.get_action()
- self.unwrapped.robot.send_action(action)
-
- log_say("Manual reset of the environment done.", play_sounds=True)
-
- busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
-
- return super().reset(seed=seed, options=options)
-
-
-class BatchCompatibleWrapper(gym.ObservationWrapper):
- """
- Wrapper that ensures observations are compatible with batch processing.
-
- This wrapper adds a batch dimension to observations that don't already have one,
- making them compatible with models that expect batched inputs.
- """
-
- def __init__(self, env):
- """
- Initialize the batch compatibility wrapper.
-
- Args:
- env: The environment to wrap.
- """
- super().__init__(env)
-
- def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
- """
- Add batch dimensions to observations if needed.
-
- Args:
- observation: Dictionary of observation tensors.
-
- Returns:
- Dictionary of observation tensors with batch dimensions.
- """
- for key in observation:
- if "image" in key and observation[key].dim() == 3:
- observation[key] = observation[key].unsqueeze(0)
- if "state" in key and observation[key].dim() == 1:
- observation[key] = observation[key].unsqueeze(0)
- if "velocity" in key and observation[key].dim() == 1:
- observation[key] = observation[key].unsqueeze(0)
- return observation
-
-
-class GripperPenaltyWrapper(gym.RewardWrapper):
- """
- Wrapper that adds penalties for inefficient gripper commands.
-
- This wrapper modifies rewards to discourage excessive gripper movement
- or commands that attempt to move the gripper beyond its physical limits.
- """
-
- def __init__(self, env, penalty: float = -0.1):
- """
- Initialize the gripper penalty wrapper.
-
- Args:
- env: The environment to wrap.
- penalty: Negative reward value to apply for inefficient gripper actions.
- """
- super().__init__(env)
- self.penalty = penalty
- self.last_gripper_state = None
-
- def reward(self, reward, action):
- """
- Apply penalties to reward based on gripper actions.
-
- Args:
- reward: The original reward from the environment.
- action: The action that was taken.
-
- Returns:
- Modified reward with penalty applied if necessary.
- """
- gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos
-
- action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND
-
- gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or (
- gripper_state_normalized > 0.75 and action_normalized < -0.5
- )
-
- return reward + self.penalty * int(gripper_penalty_bool)
-
- def step(self, action):
- """
- Step the environment and apply gripper penalties.
-
- Args:
- action: The action to take in the environment.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info) with penalty applied.
- """
- self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
-
- gripper_action = action[-1]
- obs, reward, terminated, truncated, info = self.env.step(action)
- gripper_penalty = self.reward(reward, gripper_action)
-
- info["discrete_penalty"] = gripper_penalty
-
- return obs, reward, terminated, truncated, info
-
- def reset(self, **kwargs):
- """
- Reset the environment and penalty tracking.
-
- Args:
- **kwargs: Keyword arguments passed to the wrapped environment's reset.
-
- Returns:
- The initial observation and info with gripper penalty initialized.
- """
- self.last_gripper_state = None
- obs, info = super().reset(**kwargs)
- info["gripper_penalty"] = 0.0
- return obs, info
-
-
-class GripperActionWrapper(gym.ActionWrapper):
- """
- Wrapper that processes gripper control commands.
-
- This wrapper quantizes and processes gripper commands, adding a sleep time between
- consecutive gripper actions to prevent rapid toggling.
- """
-
- def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
- """
- Initialize the gripper action wrapper.
-
- Args:
- env: The environment to wrap.
- quantization_threshold: Threshold below which gripper commands are quantized to zero.
- gripper_sleep: Minimum time in seconds between consecutive gripper commands.
- """
- super().__init__(env)
- self.quantization_threshold = quantization_threshold
- self.gripper_sleep = gripper_sleep
- self.last_gripper_action_time = 0.0
- self.last_gripper_action = None
-
- def action(self, action):
- """
- Process gripper commands in the action.
-
- Args:
- action: The original action from the agent.
-
- Returns:
- Modified action with processed gripper command.
- """
- if self.gripper_sleep > 0.0:
- if (
- self.last_gripper_action is not None
- and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep
- ):
- action[-1] = self.last_gripper_action
- else:
- self.last_gripper_action_time = time.perf_counter()
- self.last_gripper_action = action[-1]
-
- gripper_command = action[-1]
- # Gripper actions are between 0, 2
- # we want to quantize them to -1, 0 or 1
- gripper_command = gripper_command - 1.0
-
- if self.quantization_threshold is not None:
- # Quantize gripper command to -1, 0 or 1
- gripper_command = (
- np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
- )
- gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos
-
- gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"]
-
- gripper_action_value = np.clip(
- gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos
- )
- action[-1] = gripper_action_value.item()
- return action
-
- def reset(self, **kwargs):
- """
- Reset the gripper action tracking.
-
- Args:
- **kwargs: Keyword arguments passed to the wrapped environment's reset.
-
- Returns:
- The initial observation and info.
- """
- obs, info = super().reset(**kwargs)
- self.last_gripper_action_time = 0.0
- self.last_gripper_action = None
- return obs, info
-
-
-class EEObservationWrapper(gym.ObservationWrapper):
- """
- Wrapper that adds end-effector pose information to observations.
-
- This wrapper computes the end-effector pose using forward kinematics
- and adds it to the observation space.
- """
-
- def __init__(self, env, ee_pose_limits):
- """
- Initialize the end-effector observation wrapper.
-
- Args:
- env: The environment to wrap.
- ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose.
- """
- super().__init__(env)
-
- # Extend observation space to include end effector pose
- prev_space = self.observation_space["observation.state"]
-
- self.observation_space["observation.state"] = gym.spaces.Box(
- low=np.concatenate([prev_space.low, ee_pose_limits["min"]]),
- high=np.concatenate([prev_space.high, ee_pose_limits["max"]]),
- shape=(prev_space.shape[0] + 3,),
- dtype=np.float32,
- )
-
- self.kinematics = RobotKinematics(
- urdf_path=env.unwrapped.robot.config.urdf_path,
- target_frame_name=env.unwrapped.robot.config.target_frame_name,
- )
-
- def observation(self, observation):
- """
- Add end-effector pose to the observation.
-
- Args:
- observation: Original observation from the environment.
-
- Returns:
- Enhanced observation with end-effector pose information.
- """
- current_joint_pos = self.unwrapped.current_observation["agent_pos"]
-
- current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos)[:3, 3]
- observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1)
- return observation
-
-
-###########################################################
-# Wrappers related to human intervention and input devices
-###########################################################
-
-
-class BaseLeaderControlWrapper(gym.Wrapper):
- """
- Base class for leader-follower robot control wrappers.
-
- This wrapper enables human intervention through a leader-follower robot setup,
- where the human can control a leader robot to guide the follower robot's movements.
- """
-
- def __init__(
- self,
- env,
- teleop_device,
- end_effector_step_sizes,
- use_geared_leader_arm: bool = False,
- use_gripper=False,
- ):
- """
- Initialize the base leader control wrapper.
-
- Args:
- env: The environment to wrap.
- teleop_device: The teleoperation device.
- use_geared_leader_arm: Whether to use a geared leader arm setup.
- use_gripper: Whether to include gripper control.
- """
- super().__init__(env)
- self.robot_leader = teleop_device
- self.robot_follower = env.unwrapped.robot
- self.use_geared_leader_arm = use_geared_leader_arm
- self.use_gripper: bool = use_gripper
- self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values()))
-
- # Set up keyboard event tracking
- self._init_keyboard_events()
- self.event_lock = Lock() # Thread-safe access to events
-
- # Initialize robot control
- self.kinematics = RobotKinematics(
- urdf_path=env.unwrapped.robot.config.urdf_path,
- target_frame_name=env.unwrapped.robot.config.target_frame_name,
- )
- self.leader_torque_enabled = True
- self.prev_leader_gripper = None
-
- # Configure leader arm
- # NOTE: Lower the gains of leader arm for automatic take-over
- # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot
- # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled
- # Default value for P_coeff is 32
- self.robot_leader.bus.sync_write("Torque_Enable", 1)
- for motor in self.robot_leader.bus.motors:
- self.robot_leader.bus.write("P_Coefficient", motor, 16)
- self.robot_leader.bus.write("I_Coefficient", motor, 0)
- self.robot_leader.bus.write("D_Coefficient", motor, 16)
-
- self.leader_tracking_error_queue = deque(maxlen=4)
- self._init_keyboard_listener()
-
- def _init_keyboard_events(self):
- """
- Initialize the keyboard events dictionary.
-
- This method sets up tracking for keyboard events used for intervention control.
- It should be overridden in subclasses to add additional events.
- """
- self.keyboard_events = {
- "episode_success": False,
- "episode_end": False,
- "rerecord_episode": False,
- }
-
- def _handle_key_press(self, key, keyboard_device):
- """
- Handle key press events.
-
- Args:
- key: The key that was pressed.
- keyboard: The keyboard module with key definitions.
-
- This method should be overridden in subclasses for additional key handling.
- """
- try:
- if key == keyboard_device.Key.esc:
- self.keyboard_events["episode_end"] = True
- return
- if key == keyboard_device.Key.left:
- self.keyboard_events["rerecord_episode"] = True
- return
- if hasattr(key, "char") and key.char == "s":
- logging.info("Key 's' pressed. Episode success triggered.")
- self.keyboard_events["episode_success"] = True
- return
- except Exception as e:
- logging.error(f"Error handling key press: {e}")
-
- def _init_keyboard_listener(self):
- """
- Initialize the keyboard listener for intervention control.
-
- This method sets up keyboard event handling if not in headless mode.
- """
- from pynput import keyboard as keyboard_device
-
- def on_press(key):
- with self.event_lock:
- self._handle_key_press(key, keyboard_device)
-
- self.listener = keyboard_device.Listener(on_press=on_press)
- self.listener.start()
-
- def _check_intervention(self):
- """
- Check if human intervention is needed.
-
- Returns:
- Boolean indicating whether intervention is needed.
-
- This method should be overridden in subclasses with specific intervention logic.
- """
- return False
-
- def _handle_intervention(self, action):
- """
- Process actions during intervention mode.
-
- Args:
- action: The original action from the agent.
-
- Returns:
- Tuple of (modified_action, intervention_action).
- """
- if self.leader_torque_enabled:
- self.robot_leader.bus.sync_write("Torque_Enable", 0)
- self.leader_torque_enabled = False
-
- leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position")
- follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position")
-
- leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict])
- follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict])
-
- self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1]))
-
- # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation
- leader_ee = self.kinematics.forward_kinematics(leader_pos)[:3, 3]
- follower_ee = self.kinematics.forward_kinematics(follower_pos)[:3, 3]
-
- action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes)
- # Normalize the action to the range [-1, 1]
- action = action / self.end_effector_step_sizes
-
- if self.use_gripper:
- if self.prev_leader_gripper is None:
- self.prev_leader_gripper = np.clip(
- leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos
- )
-
- # Get gripper action delta based on leader pose
- leader_gripper = leader_pos[-1]
- gripper_delta = leader_gripper - self.prev_leader_gripper
-
- # Normalize by max angle and quantize to {0,1,2}
- normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos
- if normalized_delta >= 0.3:
- gripper_action = 2
- elif normalized_delta <= 0.1:
- gripper_action = 0
- else:
- gripper_action = 1
-
- action = np.append(action, gripper_action)
-
- return action
-
- def _handle_leader_teleoperation(self):
- """
- Handle leader teleoperation in non-intervention mode.
-
- This method synchronizes the leader robot position with the follower.
- """
-
- prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position")
- prev_leader_pos = np.array(
- [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32
- )
-
- if not self.leader_torque_enabled:
- self.robot_leader.bus.sync_write("Torque_Enable", 1)
- self.leader_torque_enabled = True
-
- follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position")
- follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32)
-
- goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)}
- self.robot_leader.bus.sync_write("Goal_Position", goal_pos)
-
- self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1]))
-
- def step(self, action):
- """
- Execute a step with possible human intervention.
-
- Args:
- action: The action to take in the environment.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info).
- """
- is_intervention = self._check_intervention()
-
- # NOTE:
- if is_intervention:
- action = self._handle_intervention(action)
- else:
- self._handle_leader_teleoperation()
-
- # NOTE:
- obs, reward, terminated, truncated, info = self.env.step(action)
-
- if isinstance(action, np.ndarray):
- action = torch.from_numpy(action)
-
- # Add intervention info
- info["is_intervention"] = is_intervention
- info["action_intervention"] = action
-
- self.prev_leader_gripper = np.clip(
- self.robot_leader.bus.sync_read("Present_Position")["gripper"],
- 0,
- self.robot_follower.config.max_gripper_pos,
- )
-
- # Check for success or manual termination
- success = self.keyboard_events["episode_success"]
- terminated = terminated or self.keyboard_events["episode_end"] or success
-
- if success:
- reward = 1.0
- logging.info("Episode ended successfully with reward 1.0")
-
- return obs, reward, terminated, truncated, info
-
- def reset(self, **kwargs):
- """
- Reset the environment and intervention state.
-
- Args:
- **kwargs: Keyword arguments passed to the wrapped environment's reset.
-
- Returns:
- The initial observation and info.
- """
- self.keyboard_events = dict.fromkeys(self.keyboard_events, False)
- self.leader_tracking_error_queue.clear()
- return super().reset(**kwargs)
-
- def close(self):
- """
- Clean up resources, including stopping keyboard listener.
-
- Returns:
- Result of closing the wrapped environment.
- """
- if hasattr(self, "listener") and self.listener is not None:
- self.listener.stop()
- return self.env.close()
-
-
-class GearedLeaderControlWrapper(BaseLeaderControlWrapper):
- """
- Wrapper that enables manual intervention via keyboard.
-
- This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling
- of human intervention mode with keyboard controls.
- """
-
- def _init_keyboard_events(self):
- """
- Initialize keyboard events including human intervention flag.
-
- Extends the base class dictionary with an additional flag for tracking
- intervention state toggled by keyboard.
- """
- super()._init_keyboard_events()
- self.keyboard_events["human_intervention_step"] = False
-
- def _handle_key_press(self, key, keyboard_device):
- """
- Handle key presses including space for intervention toggle.
-
- Args:
- key: The key that was pressed.
- keyboard: The keyboard module with key definitions.
-
- Extends the base handler to respond to space key for toggling intervention.
- """
- super()._handle_key_press(key, keyboard_device)
- if key == keyboard_device.Key.space:
- if not self.keyboard_events["human_intervention_step"]:
- logging.info(
- "Space key pressed. Human intervention required.\n"
- "Place the leader in similar pose to the follower and press space again."
- )
- self.keyboard_events["human_intervention_step"] = True
- log_say("Human intervention step.", play_sounds=True)
- else:
- self.keyboard_events["human_intervention_step"] = False
- logging.info("Space key pressed for a second time.\nContinuing with policy actions.")
- log_say("Continuing with policy actions.", play_sounds=True)
-
- def _check_intervention(self):
- """
- Check if human intervention is active based on keyboard toggle.
-
- Returns:
- Boolean indicating whether intervention mode is active.
- """
- return self.keyboard_events["human_intervention_step"]
-
-
-class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper):
- """
- Wrapper with automatic intervention based on error thresholds.
-
- This wrapper monitors the error between leader and follower positions
- and automatically triggers intervention when error exceeds thresholds.
- """
-
- def __init__(
- self,
- env,
- teleop_device,
- end_effector_step_sizes,
- use_gripper=False,
- intervention_threshold=10.0,
- release_threshold=1e-2,
- ):
- """
- Initialize the automatic intervention wrapper.
-
- Args:
- env: The environment to wrap.
- teleop_device: The teleoperation device.
- use_gripper: Whether to include gripper control.
- intervention_threshold: Error threshold to trigger intervention.
- release_threshold: Error threshold to release intervention.
- queue_size: Number of error measurements to track for smoothing.
- """
- super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper)
-
- # Error tracking parameters
- self.intervention_threshold = intervention_threshold # Threshold to trigger intervention
- self.release_threshold = release_threshold # Threshold to release intervention
- self.is_intervention_active = False
- self.start_time = time.perf_counter()
-
- def _check_intervention(self):
- """
- Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space.
-
- This method monitors the rate of change of leader-follower error in end_effector space
- and automatically triggers intervention when the rate of change exceeds
- the intervention threshold, releasing when it falls below the release threshold.
-
- Returns:
- Boolean indicating whether intervention should be active.
- """
-
- # Condition for starting the intervention
- # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over
- if (
- not self.is_intervention_active
- and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen
- and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold
- ):
- self.is_intervention_active = True
- self.leader_tracking_error_queue.clear()
- log_say("Intervention started", play_sounds=True)
- return True
-
- # Track the error over time in leader_tracking_error_queue
- # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over
- if (
- self.is_intervention_active
- and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen
- and np.var(self.leader_tracking_error_queue) < self.release_threshold
- ):
- self.is_intervention_active = False
- self.leader_tracking_error_queue.clear()
- log_say("Intervention ended", play_sounds=True)
- return False
-
- # If not change has happened that merits a change in the intervention state, return the current state
- return self.is_intervention_active
-
- def reset(self, **kwargs):
- """
- Reset error tracking on environment reset.
-
- Args:
- **kwargs: Keyword arguments passed to the wrapped environment's reset.
-
- Returns:
- The initial observation and info.
- """
- self.is_intervention_active = False
- return super().reset(**kwargs)
-
-
-class GamepadControlWrapper(gym.Wrapper):
- """
- Wrapper that allows controlling a gym environment with a gamepad.
-
- This wrapper intercepts the step method and allows human input via gamepad
- to override the agent's actions when desired.
- """
-
- def __init__(
- self,
- env,
- teleop_device, # Accepts an instantiated teleoperator
- use_gripper=False, # This should align with teleop_device's config
- auto_reset=False,
- ):
- """
- Initialize the gamepad controller wrapper.
-
- Args:
- env: The environment to wrap.
- teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop).
- use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper).
- auto_reset: Whether to auto reset the environment when episode ends.
- """
- super().__init__(env)
-
- self.teleop_device = teleop_device
- # Ensure the teleop_device is connected if it has a connect method
- if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected:
- self.teleop_device.connect()
-
- # self.controller attribute is removed
-
- self.auto_reset = auto_reset
- # use_gripper from args should ideally match teleop_device.config.use_gripper
- # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config
- self.use_gripper = use_gripper
-
- logging.info("Gamepad control wrapper initialized with provided teleop_device.")
- print(
- "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):"
- )
- print(" Left analog stick: Move in X-Y plane")
- print(" Right analog stick: Move in Z axis (up/down)")
- print(" X/Square button: End episode (FAILURE)")
- print(" Y/Triangle button: End episode (SUCCESS)")
- print(" B/Circle button: Exit program")
-
- def get_teleop_commands(
- self,
- ) -> tuple[bool, np.ndarray, bool, bool, bool]:
- """
- Get the current action from the gamepad if any input is active.
-
- Returns:
- Tuple containing:
- - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene())
- - action: The action derived from gamepad input (from teleop_device.get_action())
- - terminate_episode: Whether episode termination was requested
- - success: Whether episode success was signaled
- - rerecord_episode: Whether episode rerecording was requested
- """
- if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None:
- raise AttributeError(
- "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper."
- )
-
- # Get status flags from the underlying gamepad controller within the teleop_device
- self.teleop_device.gamepad.update() # Ensure gamepad state is fresh
- intervention_is_active = self.teleop_device.gamepad.should_intervene()
- episode_end_status = self.teleop_device.gamepad.get_episode_end_status()
-
- terminate_episode = episode_end_status is not None
- success = episode_end_status == "success"
- rerecord_episode = episode_end_status == "rerecord_episode"
-
- # Get the action dictionary from the teleop_device
- action_dict = self.teleop_device.get_action()
-
- # Convert action_dict to numpy array based on expected structure
- # Order: delta_x, delta_y, delta_z, gripper (if use_gripper)
- action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]]
- if self.use_gripper:
- # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open)
- # This needs to be consistent with what EEActionWrapper expects if it's used downstream
- # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open)
- # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility.
- gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present
- action_list.append(float(gripper_val))
-
- gamepad_action_np = np.array(action_list, dtype=np.float32)
-
- return (
- intervention_is_active,
- gamepad_action_np,
- terminate_episode,
- success,
- rerecord_episode,
- )
-
- def step(self, action):
- """
- Step the environment, using gamepad input to override actions when active.
-
- Args:
- action: Original action from agent.
-
- Returns:
- Tuple of (observation, reward, terminated, truncated, info).
- """
- # Get gamepad state and action
- (
- is_intervention,
- gamepad_action,
- terminate_episode,
- success,
- rerecord_episode,
- ) = self.get_teleop_commands()
-
- # Update episode ending state if requested
- if terminate_episode:
- logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")
-
- # Only override the action if gamepad is active
- action = gamepad_action if is_intervention else action
-
- # Step the environment
- obs, reward, terminated, truncated, info = self.env.step(action)
-
- # Add episode ending if requested via gamepad
- terminated = terminated or truncated or terminate_episode
-
- if success:
- reward = 1.0
- logging.info("Episode ended successfully with reward 1.0")
-
- if isinstance(action, np.ndarray):
- action = torch.from_numpy(action)
-
- info["is_intervention"] = is_intervention
- # The original `BaseLeaderControlWrapper` puts `action_intervention` in info.
- # For Gamepad, if intervention, `gamepad_action` is the intervention.
- # If not intervention, policy's action is `action`.
- # For consistency, let's store the *human's* action if intervention occurred.
- info["action_intervention"] = action
-
- info["rerecord_episode"] = rerecord_episode
-
- # If episode ended, reset the state
- if terminated or truncated:
- # Add success/failure information to info dict
- info["next.success"] = success
-
- # Auto reset if configured
- if self.auto_reset:
- obs, reset_info = self.reset()
- info.update(reset_info)
-
- return obs, reward, terminated, truncated, info
-
- def close(self):
- """
- Clean up resources when environment closes.
-
- Returns:
- Result of closing the wrapped environment.
- """
- if hasattr(self.teleop_device, "disconnect"):
- self.teleop_device.disconnect()
-
- # Call the parent close method
- return self.env.close()
-
-
-class KeyboardControlWrapper(GamepadControlWrapper):
- """
- Wrapper that allows controlling a gym environment with a keyboard.
-
- This wrapper intercepts the step method and allows human input via keyboard
- to override the agent's actions when desired.
-
- Inherits from GamepadControlWrapper to avoid code duplication.
- """
-
- def __init__(
- self,
- env,
- teleop_device, # Accepts an instantiated teleoperator
- use_gripper=False, # This should align with teleop_device's config
- auto_reset=False,
- ):
- """
- Initialize the gamepad controller wrapper.
-
- Args:
- env: The environment to wrap.
- teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop).
- use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper).
- auto_reset: Whether to auto reset the environment when episode ends.
- """
- super().__init__(env, teleop_device, use_gripper, auto_reset)
-
- self.is_intervention_active = False
-
- logging.info("Keyboard control wrapper initialized with provided teleop_device.")
- print("Keyboard controls:")
- print(" Arrow keys: Move in X-Y plane")
- print(" Shift and Shift_R: Move in Z axis")
- print(" Right Ctrl and Left Ctrl: Open and close gripper")
- print(" f: End episode with FAILURE")
- print(" s: End episode with SUCCESS")
- print(" r: End episode with RERECORD")
- print(" i: Start/Stop Intervention")
-
- def get_teleop_commands(
- self,
- ) -> tuple[bool, np.ndarray, bool, bool, bool]:
- action_dict = self.teleop_device.get_action()
- episode_end_status = None
-
- # Unroll the misc_keys_queue to check for events related to intervention, episode success, etc.
- while not self.teleop_device.misc_keys_queue.empty():
- key = self.teleop_device.misc_keys_queue.get()
- if key == "i":
- self.is_intervention_active = not self.is_intervention_active
- elif key == "f":
- episode_end_status = "failure"
- elif key == "s":
- episode_end_status = "success"
- elif key == "r":
- episode_end_status = "rerecord_episode"
-
- terminate_episode = episode_end_status is not None
- success = episode_end_status == "success"
- rerecord_episode = episode_end_status == "rerecord_episode"
-
- # Convert action_dict to numpy array based on expected structure
- # Order: delta_x, delta_y, delta_z, gripper (if use_gripper)
- action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]]
- if self.use_gripper:
- # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open)
- # This needs to be consistent with what EEActionWrapper expects if it's used downstream
- # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open)
- # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility.
- gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present
- action_list.append(float(gripper_val))
-
- gamepad_action_np = np.array(action_list, dtype=np.float32)
-
- return (
- self.is_intervention_active,
- gamepad_action_np,
- terminate_episode,
- success,
- rerecord_episode,
- )
-
-
-class GymHilDeviceWrapper(gym.Wrapper):
- def __init__(self, env, device="cpu"):
- super().__init__(env)
- self.device = device
-
- def step(self, action):
- obs, reward, terminated, truncated, info = self.env.step(action)
- for k in obs:
- obs[k] = obs[k].to(self.device)
- if "action_intervention" in info:
- # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
- info["action_intervention"] = info["action_intervention"].astype(np.float32)
- info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
- return obs, reward, terminated, truncated, info
-
- def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
- obs, info = self.env.reset(seed=seed, options=options)
- for k in obs:
- obs[k] = obs[k].to(self.device)
- if "action_intervention" in info:
- # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
- info["action_intervention"] = info["action_intervention"].astype(np.float32)
- info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
- return obs, info
-
-
-class GymHilObservationProcessorWrapper(gym.ObservationWrapper):
- def __init__(self, env: gym.Env):
- super().__init__(env)
- prev_space = self.observation_space
- new_space = {}
-
- for key in prev_space:
- if "pixels" in key:
- for k in prev_space["pixels"]:
- new_space[f"observation.images.{k}"] = gym.spaces.Box(
- 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8
- )
-
- if key == "agent_pos":
- new_space["observation.state"] = prev_space["agent_pos"]
-
- self.observation_space = gym.spaces.Dict(new_space)
-
- def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
- return preprocess_observation(observation)
-
-
-###########################################################
-# Factory functions
-###########################################################
-
-
-def make_robot_env(cfg: EnvConfig) -> gym.Env:
- """
- Factory function to create a robot environment.
-
- This function builds a robot environment with all necessary wrappers
- based on the provided configuration.
-
- Args:
- cfg: Configuration object containing environment parameters.
-
- Returns:
- A gym environment with all necessary wrappers applied.
- """
- if cfg.type == "hil":
- import gym_hil # noqa: F401
-
- # TODO (azouitine)
- env = gym.make(
- f"gym_hil/{cfg.task}",
- image_obs=True,
- render_mode="human",
- use_gripper=cfg.wrapper.use_gripper,
- gripper_penalty=cfg.wrapper.gripper_penalty,
- )
- env = GymHilObservationProcessorWrapper(env=env)
- env = GymHilDeviceWrapper(env=env, device=cfg.device)
- env = BatchCompatibleWrapper(env=env)
- env = TorchActionWrapper(env=env, device=cfg.device)
- return env
-
- if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"):
- raise ValueError(
- "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop."
- )
-
- if cfg.robot is None:
- raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.")
- robot = make_robot_from_config(cfg.robot)
- teleop_device = make_teleoperator_from_config(cfg.teleop)
- teleop_device.connect()
-
- # Create base environment
- env = RobotEnv(
- robot=robot,
- use_gripper=cfg.wrapper.use_gripper,
- display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False,
- )
-
- # Add observation and image processing
- if cfg.wrapper:
- if cfg.wrapper.add_joint_velocity_to_observation:
- env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
- if cfg.wrapper.add_current_to_observation:
- env = AddCurrentToObservation(env=env)
- if cfg.wrapper.add_ee_pose_to_observation:
- env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds)
-
- env = ConvertToLeRobotObservation(env=env, device=cfg.device)
-
- if cfg.wrapper and cfg.wrapper.crop_params_dict is not None:
- env = ImageCropResizeWrapper(
- env=env,
- crop_params_dict=cfg.wrapper.crop_params_dict,
- resize_size=cfg.wrapper.resize_size,
- )
-
- # Add reward computation and control wrappers
- reward_classifier = init_reward_classifier(cfg)
- if reward_classifier is not None:
- env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
-
- env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
- if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None:
- env = GripperPenaltyWrapper(
- env=env,
- penalty=cfg.wrapper.gripper_penalty,
- )
-
- # Control mode specific wrappers
- control_mode = cfg.wrapper.control_mode
- if control_mode == "gamepad":
- assert isinstance(teleop_device, GamepadTeleop), (
- "teleop_device must be an instance of GamepadTeleop for gamepad control mode"
- )
- env = GamepadControlWrapper(
- env=env,
- teleop_device=teleop_device,
- use_gripper=cfg.wrapper.use_gripper,
- )
- elif control_mode == "keyboard_ee":
- assert isinstance(teleop_device, KeyboardEndEffectorTeleop), (
- "teleop_device must be an instance of KeyboardEndEffectorTeleop for keyboard control mode"
- )
- env = KeyboardControlWrapper(
- env=env,
- teleop_device=teleop_device,
- use_gripper=cfg.wrapper.use_gripper,
- )
- elif control_mode == "leader":
- env = GearedLeaderControlWrapper(
- env=env,
- teleop_device=teleop_device,
- end_effector_step_sizes=cfg.robot.end_effector_step_sizes,
- use_gripper=cfg.wrapper.use_gripper,
- )
- elif control_mode == "leader_automatic":
- env = GearedLeaderAutomaticControlWrapper(
- env=env,
- teleop_device=teleop_device,
- end_effector_step_sizes=cfg.robot.end_effector_step_sizes,
- use_gripper=cfg.wrapper.use_gripper,
- )
- else:
- raise ValueError(f"Invalid control mode: {control_mode}")
-
- env = ResetWrapper(
- env=env,
- reset_pose=cfg.wrapper.fixed_reset_joint_positions,
- reset_time_s=cfg.wrapper.reset_time_s,
- )
-
- env = BatchCompatibleWrapper(env=env)
- env = TorchActionWrapper(env=env, device=cfg.device)
-
- return env
-
-
-def init_reward_classifier(cfg):
- """
- Load a reward classifier policy from a pretrained path if configured.
-
- Args:
- cfg: The environment configuration containing classifier paths.
-
- Returns:
- The loaded classifier model or None if not configured.
- """
- if cfg.reward_classifier_pretrained_path is None:
- return None
-
- from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
-
- # Get device from config or default to CUDA
- device = getattr(cfg, "device", "cpu")
-
- # Load the classifier directly using from_pretrained
- classifier = Classifier.from_pretrained(
- pretrained_name_or_path=cfg.reward_classifier_pretrained_path,
- )
-
- # Ensure model is on the correct device
- classifier.to(device)
- classifier.eval() # Set to evaluation mode
-
- return classifier
-
-
-###########################################################
-# Record and replay functions
-###########################################################
-
-
-def record_dataset(env, policy, cfg):
- """
- Record a dataset of robot interactions using either a policy or teleop.
-
- This function runs episodes in the environment and records the observations,
- actions, and results for dataset creation.
-
- Args:
- env: The environment to record from.
- policy: Optional policy to generate actions (if None, uses teleop).
- cfg: Configuration object containing recording parameters like:
- - repo_id: Repository ID for dataset storage
- - dataset_root: Local root directory for dataset
- - num_episodes: Number of episodes to record
- - fps: Frames per second for recording
- - push_to_hub: Whether to push dataset to Hugging Face Hub
- - task: Name/description of the task being recorded
- - number_of_steps_after_success: Number of additional steps to continue recording after
- a success (reward=1) is detected. This helps collect
- more positive examples for reward classifier training.
- """
- from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
- # Setup initial action (zero action if using teleop)
- action = env.action_space.sample() * 0.0
-
- action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"]
- if cfg.wrapper.use_gripper:
- action_names.append("gripper_delta")
-
- # Configure dataset features based on environment spaces
- features = {
- "observation.state": {
- "dtype": "float32",
- "shape": env.observation_space["observation.state"].shape,
- "names": None,
- },
- "action": {
- "dtype": "float32",
- "shape": (len(action_names),),
- "names": action_names,
- },
- "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
- "next.done": {"dtype": "bool", "shape": (1,), "names": None},
- "complementary_info.discrete_penalty": {
- "dtype": "float32",
- "shape": (1,),
- "names": ["discrete_penalty"],
- },
- }
-
- # Add image features
- for key in env.observation_space:
- if "image" in key:
- features[key] = {
- "dtype": "video",
- "shape": env.observation_space[key].shape,
- "names": ["channels", "height", "width"],
- }
-
- # Create dataset
- dataset = LeRobotDataset.create(
- cfg.repo_id,
- cfg.fps,
- root=cfg.dataset_root,
- use_videos=True,
- image_writer_threads=4,
- image_writer_processes=0,
- features=features,
- )
-
- # Record episodes
- episode_index = 0
- recorded_action = None
- while episode_index < cfg.num_episodes:
- obs, _ = env.reset()
- start_episode_t = time.perf_counter()
- log_say(f"Recording episode {episode_index}", play_sounds=True)
-
- # Track success state collection
- success_detected = False
- success_steps_collected = 0
-
- # Run episode steps
- while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s:
- start_loop_t = time.perf_counter()
-
- # Get action from policy if available
- if cfg.pretrained_policy_name_or_path is not None:
- action = policy.select_action(obs)
-
- # Step environment
- obs, reward, terminated, truncated, info = env.step(action)
-
- # Check if episode needs to be rerecorded
- if info.get("rerecord_episode", False):
- break
-
- # For teleop, get action from intervention
- recorded_action = {
- "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action
- }
-
- # Process observation for dataset
- obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
-
- # Check if we've just detected success
- if reward == 1.0 and not success_detected:
- success_detected = True
- logging.info("Success detected! Collecting additional success states.")
-
- # Add frame to dataset - continue marking as success even during extra collection steps
- frame = {**obs_processed, **recorded_action}
-
- # If we're in the success collection phase, keep marking rewards as 1.0
- if success_detected:
- frame["next.reward"] = np.array([1.0], dtype=np.float32)
- else:
- frame["next.reward"] = np.array([reward], dtype=np.float32)
-
- # Only mark as done if we're truly done (reached end or collected enough success states)
- really_done = terminated or truncated
- if success_detected:
- success_steps_collected += 1
- really_done = success_steps_collected >= cfg.number_of_steps_after_success
-
- frame["next.done"] = np.array([really_done], dtype=bool)
- frame["complementary_info.discrete_penalty"] = torch.tensor(
- [info.get("discrete_penalty", 0.0)], dtype=torch.float32
- )
- dataset.add_frame(frame, task=cfg.task)
-
- # Maintain consistent timing
- if cfg.fps:
- dt_s = time.perf_counter() - start_loop_t
- busy_wait(1 / cfg.fps - dt_s)
-
- # Check if we should end the episode
- if (terminated or truncated) and not success_detected:
- # Regular termination without success
- break
- elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success:
- # We've collected enough success states
- logging.info(f"Collected {success_steps_collected} additional success states")
- break
-
- # Handle episode recording
- if info.get("rerecord_episode", False):
- dataset.clear_episode_buffer()
- logging.info(f"Re-recording episode {episode_index}")
- continue
-
- dataset.save_episode()
- episode_index += 1
-
- # Finalize dataset
- # dataset.consolidate(run_compute_stats=True)
- if cfg.push_to_hub:
- dataset.push_to_hub()
-
-
-def replay_episode(env, cfg):
- """
- Replay a recorded episode in the environment.
-
- This function loads actions from a previously recorded episode
- and executes them in the environment.
-
- Args:
- env: The environment to replay in.
- cfg: Configuration object containing replay parameters:
- - repo_id: Repository ID for dataset
- - dataset_root: Local root directory for dataset
- - episode: Episode ID to replay
- """
- from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
- dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode])
- env.reset()
-
- actions = dataset.hf_dataset.select_columns("action")
-
- for idx in range(dataset.num_frames):
- start_episode_t = time.perf_counter()
-
- action = actions[idx]["action"]
- env.step(action)
-
- dt_s = time.perf_counter() - start_episode_t
- busy_wait(1 / 10 - dt_s)
-
-
-@parser.wrap()
-def main(cfg: EnvConfig):
- """Main entry point for the robot environment script.
-
- This function runs the robot environment in one of several modes
- based on the provided configuration.
-
- Args:
- cfg: Configuration object defining the run parameters,
- including mode (record, replay, random) and other settings.
- """
- env = make_robot_env(cfg)
-
- if cfg.mode == "record":
- policy = None
- if cfg.pretrained_policy_name_or_path is not None:
- from lerobot.policies.sac.modeling_sac import SACPolicy
-
- policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path)
- policy.to(cfg.device)
- policy.eval()
-
- record_dataset(
- env,
- policy=policy,
- cfg=cfg,
- )
- exit()
-
- if cfg.mode == "replay":
- replay_episode(
- env,
- cfg=cfg,
- )
- exit()
-
- env.reset()
-
- # Initialize the smoothed action as a random sample.
- smoothed_action = env.action_space.sample() * 0.0
-
- # Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
- # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
- alpha = 1.0
-
- num_episode = 0
- successes = []
- while num_episode < 10:
- start_loop_s = time.perf_counter()
- # Sample a new random action from the robot's action space.
- new_random_action = env.action_space.sample()
- # Update the smoothed action using an exponential moving average.
- smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
-
- # Execute the step: wrap the NumPy action in a torch tensor.
- obs, reward, terminated, truncated, info = env.step(smoothed_action)
- if terminated or truncated:
- successes.append(reward)
- env.reset()
- num_episode += 1
-
- dt_s = time.perf_counter() - start_loop_s
- busy_wait(1 / cfg.fps - dt_s)
-
- logging.info(f"Success after 20 steps {successes}")
- logging.info(f"success rate {sum(successes) / len(successes)}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/lerobot/scripts/visualize_dataset_html.py b/src/lerobot/scripts/visualize_dataset_html.py
deleted file mode 100644
index a722da603..000000000
--- a/src/lerobot/scripts/visualize_dataset_html.py
+++ /dev/null
@@ -1,482 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
-
-Note: The last frame of the episode doesnt always correspond to a final state.
-That's because our datasets are composed of transition from state to state up to
-the antepenultimate state associated to the ultimate action to arrive in the final state.
-However, there might not be a transition from a final state to another state.
-
-Note: This script aims to visualize the data used to train the neural networks.
-~What you see is what you get~. When visualizing image modality, it is often expected to observe
-lossly compression artifacts since these images have been decoded from compressed mp4 videos to
-save disk space. The compression factor applied has been tuned to not affect success rate.
-
-Example of usage:
-
-- Visualize data stored on a local machine:
-```bash
-local$ python -m lerobot.scripts.visualize_dataset_html \
- --repo-id lerobot/pusht
-
-local$ open http://localhost:9090
-```
-
-- Visualize data stored on a distant machine with a local viewer:
-```bash
-distant$ python -m lerobot.scripts.visualize_dataset_html \
- --repo-id lerobot/pusht
-
-local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
-local$ open http://localhost:9090
-```
-
-- Select episodes to visualize:
-```bash
-python -m lerobot.scripts.visualize_dataset_html \
- --repo-id lerobot/pusht \
- --episodes 7 3 5 1 4
-```
-"""
-
-import argparse
-import csv
-import json
-import logging
-import re
-import shutil
-import tempfile
-from io import StringIO
-from pathlib import Path
-
-import numpy as np
-import pandas as pd
-import requests
-from flask import Flask, redirect, render_template, request, url_for
-
-from lerobot import available_datasets
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import IterableNamespace
-from lerobot.utils.utils import init_logging
-
-
-def run_server(
- dataset: LeRobotDataset | IterableNamespace | None,
- episodes: list[int] | None,
- host: str,
- port: str,
- static_folder: Path,
- template_folder: Path,
-):
- app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
- app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
-
- @app.route("/")
- def hommepage(dataset=dataset):
- if dataset:
- dataset_namespace, dataset_name = dataset.repo_id.split("/")
- return redirect(
- url_for(
- "show_episode",
- dataset_namespace=dataset_namespace,
- dataset_name=dataset_name,
- episode_id=0,
- )
- )
-
- dataset_param, episode_param = None, None
- all_params = request.args
- if "dataset" in all_params:
- dataset_param = all_params["dataset"]
- if "episode" in all_params:
- episode_param = int(all_params["episode"])
-
- if dataset_param:
- dataset_namespace, dataset_name = dataset_param.split("/")
- return redirect(
- url_for(
- "show_episode",
- dataset_namespace=dataset_namespace,
- dataset_name=dataset_name,
- episode_id=episode_param if episode_param is not None else 0,
- )
- )
-
- featured_datasets = [
- "lerobot/aloha_static_cups_open",
- "lerobot/columbia_cairlab_pusht_real",
- "lerobot/taco_play",
- ]
- return render_template(
- "visualize_dataset_homepage.html",
- featured_datasets=featured_datasets,
- lerobot_datasets=available_datasets,
- )
-
- @app.route("//")
- def show_first_episode(dataset_namespace, dataset_name):
- first_episode_id = 0
- return redirect(
- url_for(
- "show_episode",
- dataset_namespace=dataset_namespace,
- dataset_name=dataset_name,
- episode_id=first_episode_id,
- )
- )
-
- @app.route("///episode_")
- def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
- repo_id = f"{dataset_namespace}/{dataset_name}"
- try:
- if dataset is None:
- dataset = get_dataset_info(repo_id)
- except FileNotFoundError:
- return (
- "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
- 400,
- )
- dataset_version = (
- str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
- )
- match = re.search(r"v(\d+)\.", dataset_version)
- if match:
- major_version = int(match.group(1))
- if major_version < 2:
- return "Make sure to convert your LeRobotDataset to v2 & above."
-
- episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
- dataset_info = {
- "repo_id": f"{dataset_namespace}/{dataset_name}",
- "num_samples": dataset.num_frames
- if isinstance(dataset, LeRobotDataset)
- else dataset.total_frames,
- "num_episodes": dataset.num_episodes
- if isinstance(dataset, LeRobotDataset)
- else dataset.total_episodes,
- "fps": dataset.fps,
- }
- if isinstance(dataset, LeRobotDataset):
- video_paths = [
- dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
- ]
- videos_info = [
- {
- "url": url_for("static", filename=str(video_path).replace("\\", "/")),
- "filename": video_path.parent.name,
- }
- for video_path in video_paths
- ]
- tasks = dataset.meta.episodes[episode_id]["tasks"]
- else:
- video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
- videos_info = [
- {
- "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
- + dataset.video_path.format(
- episode_chunk=int(episode_id) // dataset.chunks_size,
- video_key=video_key,
- episode_index=episode_id,
- ),
- "filename": video_key,
- }
- for video_key in video_keys
- ]
-
- response = requests.get(
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
- )
- response.raise_for_status()
- # Split into lines and parse each line as JSON
- tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
-
- filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
- tasks = filtered_tasks_jsonl[0]["tasks"]
-
- videos_info[0]["language_instruction"] = tasks
-
- if episodes is None:
- episodes = list(
- range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
- )
-
- return render_template(
- "visualize_dataset_template.html",
- episode_id=episode_id,
- episodes=episodes,
- dataset_info=dataset_info,
- videos_info=videos_info,
- episode_data_csv_str=episode_data_csv_str,
- columns=columns,
- ignored_columns=ignored_columns,
- )
-
- app.run(host=host, port=port)
-
-
-def get_ep_csv_fname(episode_id: int):
- ep_csv_fname = f"episode_{episode_id}.csv"
- return ep_csv_fname
-
-
-def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
- """Get a csv str containing timeseries data of an episode (e.g. state and action).
- This file will be loaded by Dygraph javascript to plot data in real time."""
- columns = []
-
- selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
- selected_columns.remove("timestamp")
-
- ignored_columns = []
- for column_name in selected_columns:
- shape = dataset.features[column_name]["shape"]
- shape_dim = len(shape)
- if shape_dim > 1:
- selected_columns.remove(column_name)
- ignored_columns.append(column_name)
-
- # init header of csv with state and action names
- header = ["timestamp"]
-
- for column_name in selected_columns:
- dim_state = (
- dataset.meta.shapes[column_name][0]
- if isinstance(dataset, LeRobotDataset)
- else dataset.features[column_name].shape[0]
- )
-
- if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
- column_names = dataset.features[column_name]["names"]
- while not isinstance(column_names, list):
- column_names = list(column_names.values())[0]
- else:
- column_names = [f"{column_name}_{i}" for i in range(dim_state)]
- columns.append({"key": column_name, "value": column_names})
-
- header += column_names
-
- selected_columns.insert(0, "timestamp")
-
- if isinstance(dataset, LeRobotDataset):
- from_idx = dataset.episode_data_index["from"][episode_index]
- to_idx = dataset.episode_data_index["to"][episode_index]
- data = (
- dataset.hf_dataset.select(range(from_idx, to_idx))
- .select_columns(selected_columns)
- .with_format("pandas")
- )
- else:
- repo_id = dataset.repo_id
-
- url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
- episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
- )
- df = pd.read_parquet(url)
- data = df[selected_columns] # Select specific columns
-
- rows = np.hstack(
- (
- np.expand_dims(data["timestamp"], axis=1),
- *[np.vstack(data[col]) for col in selected_columns[1:]],
- )
- ).tolist()
-
- # Convert data to CSV string
- csv_buffer = StringIO()
- csv_writer = csv.writer(csv_buffer)
- # Write header
- csv_writer.writerow(header)
- # Write data rows
- csv_writer.writerows(rows)
- csv_string = csv_buffer.getvalue()
-
- return csv_string, columns, ignored_columns
-
-
-def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
- # get first frame of episode (hack to get video_path of the episode)
- first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
- return [
- dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
- for key in dataset.meta.video_keys
- ]
-
-
-def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
- # check if the dataset has language instructions
- if "language_instruction" not in dataset.features:
- return None
-
- # get first frame index
- first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
-
- language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
- # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
- # with the tf.tensor appearing in the string
- return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
-
-
-def get_dataset_info(repo_id: str) -> IterableNamespace:
- response = requests.get(
- f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
- )
- response.raise_for_status() # Raises an HTTPError for bad responses
- dataset_info = response.json()
- dataset_info["repo_id"] = repo_id
- return IterableNamespace(dataset_info)
-
-
-def visualize_dataset_html(
- dataset: LeRobotDataset | None,
- episodes: list[int] | None = None,
- output_dir: Path | None = None,
- serve: bool = True,
- host: str = "127.0.0.1",
- port: int = 9090,
- force_override: bool = False,
-) -> Path | None:
- init_logging()
-
- template_dir = Path(__file__).resolve().parent.parent / "templates"
-
- if output_dir is None:
- # Create a temporary directory that will be automatically cleaned up
- output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
-
- output_dir = Path(output_dir)
- if output_dir.exists():
- if force_override:
- shutil.rmtree(output_dir)
- else:
- logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
-
- output_dir.mkdir(parents=True, exist_ok=True)
-
- static_dir = output_dir / "static"
- static_dir.mkdir(parents=True, exist_ok=True)
-
- if dataset is None:
- if serve:
- run_server(
- dataset=None,
- episodes=None,
- host=host,
- port=port,
- static_folder=static_dir,
- template_folder=template_dir,
- )
- else:
- # Create a simlink from the dataset video folder containing mp4 files to the output directory
- # so that the http server can get access to the mp4 files.
- if isinstance(dataset, LeRobotDataset):
- ln_videos_dir = static_dir / "videos"
- if not ln_videos_dir.exists():
- ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix())
-
- if serve:
- run_server(dataset, episodes, host, port, static_dir, template_dir)
-
-
-def main():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--repo-id",
- type=str,
- default=None,
- help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
- )
- parser.add_argument(
- "--root",
- type=Path,
- default=None,
- help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
- )
- parser.add_argument(
- "--load-from-hf-hub",
- type=int,
- default=0,
- help="Load videos and parquet files from HF Hub rather than local system.",
- )
- parser.add_argument(
- "--episodes",
- type=int,
- nargs="*",
- default=None,
- help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
- )
- parser.add_argument(
- "--output-dir",
- type=Path,
- default=None,
- help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
- )
- parser.add_argument(
- "--serve",
- type=int,
- default=1,
- help="Launch web server.",
- )
- parser.add_argument(
- "--host",
- type=str,
- default="127.0.0.1",
- help="Web host used by the http server.",
- )
- parser.add_argument(
- "--port",
- type=int,
- default=9090,
- help="Web port used by the http server.",
- )
- parser.add_argument(
- "--force-override",
- type=int,
- default=0,
- help="Delete the output directory if it exists already.",
- )
-
- parser.add_argument(
- "--tolerance-s",
- type=float,
- default=1e-4,
- help=(
- "Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
- "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
- "If not given, defaults to 1e-4."
- ),
- )
-
- args = parser.parse_args()
- kwargs = vars(args)
- repo_id = kwargs.pop("repo_id")
- load_from_hf_hub = kwargs.pop("load_from_hf_hub")
- root = kwargs.pop("root")
- tolerance_s = kwargs.pop("tolerance_s")
-
- dataset = None
- if repo_id:
- dataset = (
- LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
- if not load_from_hf_hub
- else get_dataset_info(repo_id)
- )
-
- visualize_dataset_html(dataset, **vars(args))
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py
deleted file mode 100644
index e2819345b..000000000
--- a/src/lerobot/teleoperate.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Simple script to control a robot from teleoperation.
-
-Example:
-
-```shell
-python -m lerobot.teleoperate \
- --robot.type=so101_follower \
- --robot.port=/dev/tty.usbmodem58760431541 \
- --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
- --robot.id=black \
- --teleop.type=so101_leader \
- --teleop.port=/dev/tty.usbmodem58760431551 \
- --teleop.id=blue \
- --display_data=true
-```
-"""
-
-import logging
-import time
-from dataclasses import asdict, dataclass
-from pprint import pformat
-
-import draccus
-import rerun as rr
-
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
-from lerobot.robots import ( # noqa: F401
- Robot,
- RobotConfig,
- koch_follower,
- make_robot_from_config,
- so100_follower,
- so101_follower,
-)
-from lerobot.teleoperators import ( # noqa: F401
- Teleoperator,
- TeleoperatorConfig,
- gamepad,
- koch_leader,
- make_teleoperator_from_config,
- so100_leader,
- so101_leader,
-)
-from lerobot.utils.robot_utils import busy_wait
-from lerobot.utils.utils import init_logging, move_cursor_up
-from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
-
-
-@dataclass
-class TeleoperateConfig:
- # TODO: pepijn, steven: if more robots require multiple teleoperators (like lekiwi) its good to make this possibele in teleop.py and record.py with List[Teleoperator]
- teleop: TeleoperatorConfig
- robot: RobotConfig
- # Limit the maximum frames per second.
- fps: int = 60
- teleop_time_s: float | None = None
- # Display all cameras on screen
- display_data: bool = False
-
-
-def teleop_loop(
- teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
-):
- display_len = max(len(key) for key in robot.action_features)
- start = time.perf_counter()
- while True:
- loop_start = time.perf_counter()
- action = teleop.get_action()
- if display_data:
- observation = robot.get_observation()
- log_rerun_data(observation, action)
-
- robot.send_action(action)
- dt_s = time.perf_counter() - loop_start
- busy_wait(1 / fps - dt_s)
-
- loop_s = time.perf_counter() - loop_start
-
- print("\n" + "-" * (display_len + 10))
- print(f"{'NAME':<{display_len}} | {'NORM':>7}")
- for motor, value in action.items():
- print(f"{motor:<{display_len}} | {value:>7.2f}")
- print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
-
- if duration is not None and time.perf_counter() - start >= duration:
- return
-
- move_cursor_up(len(action) + 5)
-
-
-@draccus.wrap()
-def teleoperate(cfg: TeleoperateConfig):
- init_logging()
- logging.info(pformat(asdict(cfg)))
- if cfg.display_data:
- _init_rerun(session_name="teleoperation")
-
- teleop = make_teleoperator_from_config(cfg.teleop)
- robot = make_robot_from_config(cfg.robot)
-
- teleop.connect()
- robot.connect()
-
- try:
- teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
- except KeyboardInterrupt:
- pass
- finally:
- if cfg.display_data:
- rr.rerun_shutdown()
- teleop.disconnect()
- robot.disconnect()
-
-
-if __name__ == "__main__":
- teleoperate()
diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py
index ec93547f7..ee508dddb 100644
--- a/src/lerobot/teleoperators/__init__.py
+++ b/src/lerobot/teleoperators/__init__.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
-from .utils import make_teleoperator_from_config
+from .utils import TeleopEvents, make_teleoperator_from_config
diff --git a/src/lerobot/teleoperators/bi_so100_leader/__init__.py b/src/lerobot/teleoperators/bi_so100_leader/__init__.py
new file mode 100644
index 000000000..34313a61e
--- /dev/null
+++ b/src/lerobot/teleoperators/bi_so100_leader/__init__.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .bi_so100_leader import BiSO100Leader
+from .config_bi_so100_leader import BiSO100LeaderConfig
diff --git a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py
new file mode 100644
index 000000000..769669655
--- /dev/null
+++ b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from functools import cached_property
+
+from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
+from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
+
+from ..teleoperator import Teleoperator
+from .config_bi_so100_leader import BiSO100LeaderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class BiSO100Leader(Teleoperator):
+ """
+ [Bimanual SO-100 Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
+ This bimanual leader arm can also be easily adapted to use SO-101 leader arms, just replace the SO100Leader class with SO101Leader and SO100LeaderConfig with SO101LeaderConfig.
+ """
+
+ config_class = BiSO100LeaderConfig
+ name = "bi_so100_leader"
+
+ def __init__(self, config: BiSO100LeaderConfig):
+ super().__init__(config)
+ self.config = config
+
+ left_arm_config = SO100LeaderConfig(
+ id=f"{config.id}_left" if config.id else None,
+ calibration_dir=config.calibration_dir,
+ port=config.left_arm_port,
+ )
+
+ right_arm_config = SO100LeaderConfig(
+ id=f"{config.id}_right" if config.id else None,
+ calibration_dir=config.calibration_dir,
+ port=config.right_arm_port,
+ )
+
+ self.left_arm = SO100Leader(left_arm_config)
+ self.right_arm = SO100Leader(right_arm_config)
+
+ @cached_property
+ def action_features(self) -> dict[str, type]:
+ return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | {
+ f"right_{motor}.pos": float for motor in self.right_arm.bus.motors
+ }
+
+ @cached_property
+ def feedback_features(self) -> dict[str, type]:
+ return {}
+
+ @property
+ def is_connected(self) -> bool:
+ return self.left_arm.is_connected and self.right_arm.is_connected
+
+ def connect(self, calibrate: bool = True) -> None:
+ self.left_arm.connect(calibrate)
+ self.right_arm.connect(calibrate)
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.left_arm.is_calibrated and self.right_arm.is_calibrated
+
+ def calibrate(self) -> None:
+ self.left_arm.calibrate()
+ self.right_arm.calibrate()
+
+ def configure(self) -> None:
+ self.left_arm.configure()
+ self.right_arm.configure()
+
+ def setup_motors(self) -> None:
+ self.left_arm.setup_motors()
+ self.right_arm.setup_motors()
+
+ def get_action(self) -> dict[str, float]:
+ action_dict = {}
+
+ # Add "left_" prefix
+ left_action = self.left_arm.get_action()
+ action_dict.update({f"left_{key}": value for key, value in left_action.items()})
+
+ # Add "right_" prefix
+ right_action = self.right_arm.get_action()
+ action_dict.update({f"right_{key}": value for key, value in right_action.items()})
+
+ return action_dict
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ # Remove "left_" prefix
+ left_feedback = {
+ key.removeprefix("left_"): value for key, value in feedback.items() if key.startswith("left_")
+ }
+ # Remove "right_" prefix
+ right_feedback = {
+ key.removeprefix("right_"): value for key, value in feedback.items() if key.startswith("right_")
+ }
+
+ if left_feedback:
+ self.left_arm.send_feedback(left_feedback)
+ if right_feedback:
+ self.right_arm.send_feedback(right_feedback)
+
+ def disconnect(self) -> None:
+ self.left_arm.disconnect()
+ self.right_arm.disconnect()
diff --git a/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py b/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py
new file mode 100644
index 000000000..117e09913
--- /dev/null
+++ b/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from ..config import TeleoperatorConfig
+
+
+@TeleoperatorConfig.register_subclass("bi_so100_leader")
+@dataclass
+class BiSO100LeaderConfig(TeleoperatorConfig):
+ left_arm_port: str
+ right_arm_port: str
diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py
index 9b62dc666..9f94b6746 100644
--- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py
+++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py
@@ -16,6 +16,8 @@
import logging
+from ..utils import TeleopEvents
+
class InputController:
"""Base class for input controllers that generate motion deltas."""
@@ -50,10 +52,6 @@ class InputController:
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
- def should_quit(self):
- """Return True if the user has requested to quit."""
- return not self.running
-
def update(self):
"""Update controller state - call this once per frame."""
pass
@@ -134,10 +132,10 @@ class KeyboardController(InputController):
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
- self.episode_end_status = "success"
+ self.episode_end_status = TeleopEvents.SUCCESS
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
- self.episode_end_status = "failure"
+ self.episode_end_status = TeleopEvents.FAILURE
except AttributeError:
pass
@@ -196,14 +194,6 @@ class KeyboardController(InputController):
return delta_x, delta_y, delta_z
- def should_quit(self):
- """Return True if ESC was pressed."""
- return self.key_states["quit"]
-
- def should_save(self):
- """Return True if Enter was pressed (save episode)."""
- return self.key_states["success"] or self.key_states["failure"]
-
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
@@ -255,13 +245,13 @@ class GamepadController(InputController):
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
- self.episode_end_status = "success"
+ self.episode_end_status = TeleopEvents.SUCCESS
# A button (1) for failure
elif event.button == 1:
- self.episode_end_status = "failure"
+ self.episode_end_status = TeleopEvents.FAILURE
# X button (0) for rerecord
elif event.button == 0:
- self.episode_end_status = "rerecord_episode"
+ self.episode_end_status = TeleopEvents.RERECORD_EPISODE
# RB button (6) for closing gripper
elif event.button == 6:
@@ -295,8 +285,8 @@ class GamepadController(InputController):
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
- y_input = self.joystick.get_axis(0) # Left/Right
- x_input = self.joystick.get_axis(1) # Up/Down (often inverted)
+ y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
+ x_input = self.joystick.get_axis(1) # Left/Right
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
@@ -349,8 +339,6 @@ class GamepadControllerHID(InputController):
# Button states
self.buttons = {}
- self.quit_requested = False
- self.save_requested = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
@@ -451,11 +439,11 @@ class GamepadControllerHID(InputController):
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
- self.episode_end_status = "success"
+ self.episode_end_status = TeleopEvents.SUCCESS
elif buttons & 1 << 5:
- self.episode_end_status = "failure"
+ self.episode_end_status = TeleopEvents.FAILURE
elif buttons & 1 << 4:
- self.episode_end_status = "rerecord_episode"
+ self.episode_end_status = TeleopEvents.RERECORD_EPISODE
else:
self.episode_end_status = None
@@ -470,11 +458,3 @@ class GamepadControllerHID(InputController):
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
-
- def should_quit(self):
- """Return True if quit button was pressed."""
- return self.quit_requested
-
- def should_save(self):
- """Return True if save button was pressed."""
- return self.save_requested
diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py
index 98a0647e2..c7072f4a7 100644
--- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py
+++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py
@@ -21,6 +21,7 @@ from typing import Any
import numpy as np
from ..teleoperator import Teleoperator
+from ..utils import TeleopEvents
from .configuration_gamepad import GamepadTeleopConfig
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
return action_dict
+ def get_teleop_events(self) -> dict[str, Any]:
+ """
+ Get extra control events from the gamepad such as intervention status,
+ episode termination, success indicators, etc.
+
+ Returns:
+ Dictionary containing:
+ - is_intervention: bool - Whether human is currently intervening
+ - terminate_episode: bool - Whether to terminate the current episode
+ - success: bool - Whether the episode was successful
+ - rerecord_episode: bool - Whether to rerecord the episode
+ """
+ if self.gamepad is None:
+ return {
+ TeleopEvents.IS_INTERVENTION: False,
+ TeleopEvents.TERMINATE_EPISODE: False,
+ TeleopEvents.SUCCESS: False,
+ TeleopEvents.RERECORD_EPISODE: False,
+ }
+
+ # Update gamepad state to get fresh inputs
+ self.gamepad.update()
+
+ # Check if intervention is active
+ is_intervention = self.gamepad.should_intervene()
+
+ # Get episode end status
+ episode_end_status = self.gamepad.get_episode_end_status()
+ terminate_episode = episode_end_status in [
+ TeleopEvents.RERECORD_EPISODE,
+ TeleopEvents.FAILURE,
+ ]
+ success = episode_end_status == TeleopEvents.SUCCESS
+ rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE
+
+ return {
+ TeleopEvents.IS_INTERVENTION: is_intervention,
+ TeleopEvents.TERMINATE_EPISODE: terminate_episode,
+ TeleopEvents.SUCCESS: success,
+ TeleopEvents.RERECORD_EPISODE: rerecord_episode,
+ }
+
def disconnect(self) -> None:
"""Disconnect from the gamepad."""
if self.gamepad is not None:
diff --git a/src/lerobot/teleoperators/homunculus/__init__.py b/src/lerobot/teleoperators/homunculus/__init__.py
new file mode 100644
index 000000000..b3c6c0bf5
--- /dev/null
+++ b/src/lerobot/teleoperators/homunculus/__init__.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig
+from .homunculus_arm import HomunculusArm
+from .homunculus_glove import HomunculusGlove
+from .joints_translation import homunculus_glove_to_hope_jr_hand
diff --git a/src/lerobot/teleoperators/homunculus/config_homunculus.py b/src/lerobot/teleoperators/homunculus/config_homunculus.py
new file mode 100644
index 000000000..da465215a
--- /dev/null
+++ b/src/lerobot/teleoperators/homunculus/config_homunculus.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from ..config import TeleoperatorConfig
+
+
+@TeleoperatorConfig.register_subclass("homunculus_glove")
+@dataclass
+class HomunculusGloveConfig(TeleoperatorConfig):
+ port: str # Port to connect to the glove
+ side: str # "left" / "right"
+ baud_rate: int = 115_200
+
+ def __post_init__(self):
+ if self.side not in ["right", "left"]:
+ raise ValueError(self.side)
+
+
+@TeleoperatorConfig.register_subclass("homunculus_arm")
+@dataclass
+class HomunculusArmConfig(TeleoperatorConfig):
+ port: str # Port to connect to the arm
+ baud_rate: int = 115_200
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
new file mode 100644
index 000000000..21d73de2e
--- /dev/null
+++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import threading
+from collections import deque
+from pprint import pformat
+
+import serial
+
+from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.utils import enter_pressed, move_cursor_up
+
+from ..teleoperator import Teleoperator
+from .config_homunculus import HomunculusArmConfig
+
+logger = logging.getLogger(__name__)
+
+
+class HomunculusArm(Teleoperator):
+ """
+ Homunculus Arm designed by Hugging Face.
+ """
+
+ config_class = HomunculusArmConfig
+ name = "homunculus_arm"
+
+ def __init__(self, config: HomunculusArmConfig):
+ super().__init__(config)
+ self.config = config
+ self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
+ self.serial_lock = threading.Lock()
+
+ self.joints = {
+ "shoulder_pitch": MotorNormMode.RANGE_M100_100,
+ "shoulder_yaw": MotorNormMode.RANGE_M100_100,
+ "shoulder_roll": MotorNormMode.RANGE_M100_100,
+ "elbow_flex": MotorNormMode.RANGE_M100_100,
+ "wrist_roll": MotorNormMode.RANGE_M100_100,
+ "wrist_yaw": MotorNormMode.RANGE_M100_100,
+ "wrist_pitch": MotorNormMode.RANGE_M100_100,
+ }
+ n = 50
+ # EMA parameters ---------------------------------------------------
+ self.n: int = n
+ self.alpha: float = 2 / (n + 1)
+ # one deque *per joint* so we can inspect raw history if needed
+ self._buffers: dict[str, deque[int]] = {
+ joint: deque(maxlen=n)
+ for joint in (
+ "shoulder_pitch",
+ "shoulder_yaw",
+ "shoulder_roll",
+ "elbow_flex",
+ "wrist_roll",
+ "wrist_yaw",
+ "wrist_pitch",
+ )
+ }
+ # running EMA value per joint – lazily initialised on first read
+ self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
+
+ self._state: dict[str, float] | None = None
+ self.new_state_event = threading.Event()
+ self.stop_event = threading.Event()
+ self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
+ self.state_lock = threading.Lock()
+
+ @property
+ def action_features(self) -> dict:
+ return {f"{joint}.pos": float for joint in self.joints}
+
+ @property
+ def feedback_features(self) -> dict:
+ return {}
+
+ @property
+ def is_connected(self) -> bool:
+ with self.serial_lock:
+ return self.serial.is_open and self.thread.is_alive()
+
+ def connect(self, calibrate: bool = True) -> None:
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ if not self.serial.is_open:
+ self.serial.open()
+ self.thread.start()
+
+ # wait for the thread to ramp up & 1st state to be ready
+ if not self.new_state_event.wait(timeout=2):
+ raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
+
+ if not self.is_calibrated and calibrate:
+ self.calibrate()
+
+ logger.info(f"{self} connected.")
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.calibration_fpath.is_file()
+
+ def calibrate(self) -> None:
+ print(
+ "\nMove all joints through their entire range of motion."
+ "\nRecording positions. Press ENTER to stop..."
+ )
+ range_mins, range_maxes = self._record_ranges_of_motion()
+
+ self.calibration = {}
+ for id_, joint in enumerate(self.joints):
+ self.calibration[joint] = MotorCalibration(
+ id=id_,
+ drive_mode=0,
+ homing_offset=0,
+ range_min=range_mins[joint],
+ range_max=range_maxes[joint],
+ )
+
+ self._save_calibration()
+ print("Calibration saved to", self.calibration_fpath)
+
+ # TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code.
+ def _record_ranges_of_motion(
+ self, joints: list[str] | None = None, display_values: bool = True
+ ) -> tuple[dict[str, int], dict[str, int]]:
+ """Interactively record the min/max encoder values of each joint.
+
+ Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.
+
+ Args:
+ joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
+ display_values (bool, optional): When `True` (default) a live table is printed to the console.
+
+ Raises:
+ TypeError: `joints` is not `None` or a list.
+ ValueError: any joint's recorded min and max are the same.
+
+ Returns:
+ tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
+ observed for each joint.
+ """
+ if joints is None:
+ joints = list(self.joints)
+ elif not isinstance(joints, list):
+ raise TypeError(joints)
+
+ display_len = max(len(key) for key in joints)
+
+ start_positions = self._read(joints, normalize=False)
+ mins = start_positions.copy()
+ maxes = start_positions.copy()
+
+ user_pressed_enter = False
+ while not user_pressed_enter:
+ positions = self._read(joints, normalize=False)
+ mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
+ maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}
+
+ if display_values:
+ print("\n-------------------------------------------")
+ print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
+ for joint in joints:
+ print(
+ f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
+ )
+
+ if enter_pressed():
+ user_pressed_enter = True
+
+ if display_values and not user_pressed_enter:
+ # Move cursor up to overwrite the previous output
+ move_cursor_up(len(joints) + 3)
+
+ same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
+ if same_min_max:
+ raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")
+
+ return mins, maxes
+
+ def configure(self) -> None:
+ pass
+
+ # TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code.
+ def _normalize(self, values: dict[str, int]) -> dict[str, float]:
+ if not self.calibration:
+ raise RuntimeError(f"{self} has no calibration registered.")
+
+ normalized_values = {}
+ for joint, val in values.items():
+ min_ = self.calibration[joint].range_min
+ max_ = self.calibration[joint].range_max
+ drive_mode = self.calibration[joint].drive_mode
+ bounded_val = min(max_, max(min_, val))
+
+ if self.joints[joint] is MotorNormMode.RANGE_M100_100:
+ norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
+ normalized_values[joint] = -norm if drive_mode else norm
+ elif self.joints[joint] is MotorNormMode.RANGE_0_100:
+ norm = ((bounded_val - min_) / (max_ - min_)) * 100
+ normalized_values[joint] = 100 - norm if drive_mode else norm
+
+ return normalized_values
+
+ def _apply_ema(self, raw: dict[str, int]) -> dict[str, float]:
+ """Update buffers & running EMA values; return smoothed dict."""
+ smoothed: dict[str, float] = {}
+ for joint, value in raw.items():
+ # maintain raw history
+ self._buffers[joint].append(value)
+
+ # initialise on first run
+ if self._ema[joint] is None:
+ self._ema[joint] = float(value)
+ else:
+ self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]
+
+ smoothed[joint] = self._ema[joint]
+ return smoothed
+
+ def _read(
+ self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
+ ) -> dict[str, int | float]:
+ """
+ Return the most recent (single) values from self.last_d,
+ optionally applying calibration.
+ """
+ if not self.new_state_event.wait(timeout=timeout):
+ raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
+
+ with self.state_lock:
+ state = self._state
+
+ self.new_state_event.clear()
+
+ if state is None:
+ raise RuntimeError(f"{self} Internal error: Event set but no state available.")
+
+ if joints is not None:
+ state = {k: v for k, v in state.items() if k in joints}
+
+ if normalize:
+ state = self._normalize(state)
+
+ state = self._apply_ema(state)
+
+ return state
+
+ def _read_loop(self):
+ """
+ Continuously read from the serial buffer in its own thread and sends values to the main thread through
+ a queue.
+ """
+ while not self.stop_event.is_set():
+ try:
+ raw_values = None
+ with self.serial_lock:
+ if self.serial.in_waiting > 0:
+ self.serial.flush()
+ raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
+ if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values
+ continue
+
+ joint_angles = {
+ "shoulder_pitch": int(raw_values[19]),
+ "shoulder_yaw": int(raw_values[18]),
+ "shoulder_roll": int(raw_values[20]),
+ "elbow_flex": int(raw_values[17]),
+ "wrist_roll": int(raw_values[16]),
+ "wrist_yaw": int(raw_values[1]),
+ "wrist_pitch": int(raw_values[0]),
+ }
+
+ with self.state_lock:
+ self._state = joint_angles
+ self.new_state_event.set()
+
+ except Exception as e:
+ logger.debug(f"Error reading frame in background thread for {self}: {e}")
+
+ def get_action(self) -> dict[str, float]:
+ joint_positions = self._read()
+ return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ raise NotImplementedError
+
+ def disconnect(self) -> None:
+ if not self.is_connected:
+ DeviceNotConnectedError(f"{self} is not connected.")
+
+ self.stop_event.set()
+ self.thread.join(timeout=1)
+ self.serial.close()
+ logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
new file mode 100644
index 000000000..251ecf56d
--- /dev/null
+++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
@@ -0,0 +1,337 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import threading
+from collections import deque
+from pprint import pformat
+
+import serial
+
+from lerobot.motors import MotorCalibration
+from lerobot.motors.motors_bus import MotorNormMode
+from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.utils import enter_pressed, move_cursor_up
+
+from ..teleoperator import Teleoperator
+from .config_homunculus import HomunculusGloveConfig
+
+logger = logging.getLogger(__name__)
+
+LEFT_HAND_INVERSIONS = [
+ "thumb_cmc",
+ "index_dip",
+ "middle_mcp_abduction",
+ "middle_dip",
+ "pinky_mcp_abduction",
+ "pinky_dip",
+]
+
+RIGHT_HAND_INVERSIONS = [
+ "thumb_mcp",
+ "thumb_cmc",
+ "thumb_pip",
+ "thumb_dip",
+ "index_mcp_abduction",
+ # "index_dip",
+ "middle_mcp_abduction",
+ # "middle_dip",
+ "ring_mcp_abduction",
+ "ring_mcp_flexion",
+ # "ring_dip",
+ "pinky_mcp_abduction",
+]
+
+
+class HomunculusGlove(Teleoperator):
+ """
+ Homunculus Glove designed by NepYope & Hugging Face.
+ """
+
+ config_class = HomunculusGloveConfig
+ name = "homunculus_glove"
+
+ def __init__(self, config: HomunculusGloveConfig):
+ super().__init__(config)
+ self.config = config
+ self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
+ self.serial_lock = threading.Lock()
+
+ self.joints = {
+ "thumb_cmc": MotorNormMode.RANGE_0_100,
+ "thumb_mcp": MotorNormMode.RANGE_0_100,
+ "thumb_pip": MotorNormMode.RANGE_0_100,
+ "thumb_dip": MotorNormMode.RANGE_0_100,
+ "index_mcp_abduction": MotorNormMode.RANGE_M100_100,
+ "index_mcp_flexion": MotorNormMode.RANGE_0_100,
+ "index_dip": MotorNormMode.RANGE_0_100,
+ "middle_mcp_abduction": MotorNormMode.RANGE_M100_100,
+ "middle_mcp_flexion": MotorNormMode.RANGE_0_100,
+ "middle_dip": MotorNormMode.RANGE_0_100,
+ "ring_mcp_abduction": MotorNormMode.RANGE_M100_100,
+ "ring_mcp_flexion": MotorNormMode.RANGE_0_100,
+ "ring_dip": MotorNormMode.RANGE_0_100,
+ "pinky_mcp_abduction": MotorNormMode.RANGE_M100_100,
+ "pinky_mcp_flexion": MotorNormMode.RANGE_0_100,
+ "pinky_dip": MotorNormMode.RANGE_0_100,
+ }
+ self.inverted_joints = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS
+
+ n = 10
+ # EMA parameters ---------------------------------------------------
+ self.n: int = n
+ self.alpha: float = 2 / (n + 1)
+ # one deque *per joint* so we can inspect raw history if needed
+ self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
+ # running EMA value per joint – lazily initialised on first read
+ self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
+
+ self._state: dict[str, float] | None = None
+ self.new_state_event = threading.Event()
+ self.stop_event = threading.Event()
+ self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
+ self.state_lock = threading.Lock()
+
+ @property
+ def action_features(self) -> dict:
+ return {f"{joint}.pos": float for joint in self.joints}
+
+ @property
+ def feedback_features(self) -> dict:
+ return {}
+
+ @property
+ def is_connected(self) -> bool:
+ with self.serial_lock:
+ return self.serial.is_open and self.thread.is_alive()
+
+ def connect(self, calibrate: bool = True) -> None:
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ if not self.serial.is_open:
+ self.serial.open()
+ self.thread.start()
+
+ # wait for the thread to ramp up & 1st state to be ready
+ if not self.new_state_event.wait(timeout=2):
+ raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
+
+ if not self.is_calibrated and calibrate:
+ self.calibrate()
+
+ logger.info(f"{self} connected.")
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self.calibration_fpath.is_file()
+
+ def calibrate(self) -> None:
+ range_mins, range_maxes = {}, {}
+ for finger in ["thumb", "index", "middle", "ring", "pinky"]:
+ print(
+ f"\nMove {finger} through its entire range of motion."
+ "\nRecording positions. Press ENTER to stop..."
+ )
+ finger_joints = [joint for joint in self.joints if joint.startswith(finger)]
+ finger_mins, finger_maxes = self._record_ranges_of_motion(finger_joints)
+ range_mins.update(finger_mins)
+ range_maxes.update(finger_maxes)
+
+ self.calibration = {}
+ for id_, joint in enumerate(self.joints):
+ self.calibration[joint] = MotorCalibration(
+ id=id_,
+ drive_mode=1 if joint in self.inverted_joints else 0,
+ homing_offset=0,
+ range_min=range_mins[joint],
+ range_max=range_maxes[joint],
+ )
+
+ self._save_calibration()
+ print("Calibration saved to", self.calibration_fpath)
+
+ # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code.
+ def _record_ranges_of_motion(
+ self, joints: list[str] | None = None, display_values: bool = True
+ ) -> tuple[dict[str, int], dict[str, int]]:
+ """Interactively record the min/max encoder values of each joint.
+
+ Move the joints while the method streams live positions. Press :kbd:`Enter` to finish.
+
+ Args:
+ joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`).
+ display_values (bool, optional): When `True` (default) a live table is printed to the console.
+
+ Raises:
+ TypeError: `joints` is not `None` or a list.
+ ValueError: any joint's recorded min and max are the same.
+
+ Returns:
+ tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values
+ observed for each joint.
+ """
+ if joints is None:
+ joints = list(self.joints)
+ elif not isinstance(joints, list):
+ raise TypeError(joints)
+
+ display_len = max(len(key) for key in joints)
+
+ start_positions = self._read(joints, normalize=False)
+ mins = start_positions.copy()
+ maxes = start_positions.copy()
+
+ user_pressed_enter = False
+ while not user_pressed_enter:
+ positions = self._read(joints, normalize=False)
+ mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()}
+ maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()}
+
+ if display_values:
+ print("\n-------------------------------------------")
+ print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
+ for joint in joints:
+ print(
+ f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}"
+ )
+
+ if enter_pressed():
+ user_pressed_enter = True
+
+ if display_values and not user_pressed_enter:
+ # Move cursor up to overwrite the previous output
+ move_cursor_up(len(joints) + 3)
+
+ same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]]
+ if same_min_max:
+ raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}")
+
+ return mins, maxes
+
+ def configure(self) -> None:
+ pass
+
+ # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code.
+ def _normalize(self, values: dict[str, int]) -> dict[str, float]:
+ if not self.calibration:
+ raise RuntimeError(f"{self} has no calibration registered.")
+
+ normalized_values = {}
+ for joint, val in values.items():
+ min_ = self.calibration[joint].range_min
+ max_ = self.calibration[joint].range_max
+ drive_mode = self.calibration[joint].drive_mode
+ bounded_val = min(max_, max(min_, val))
+
+ if self.joints[joint] is MotorNormMode.RANGE_M100_100:
+ norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
+ normalized_values[joint] = -norm if drive_mode else norm
+ elif self.joints[joint] is MotorNormMode.RANGE_0_100:
+ norm = ((bounded_val - min_) / (max_ - min_)) * 100
+ normalized_values[joint] = 100 - norm if drive_mode else norm
+
+ return normalized_values
+
+ def _apply_ema(self, raw: dict[str, int]) -> dict[str, int]:
+ """Update buffers & running EMA values; return smoothed dict as integers."""
+ smoothed: dict[str, int] = {}
+ for joint, value in raw.items():
+ # maintain raw history
+ self._buffers[joint].append(value)
+
+ # initialise on first run
+ if self._ema[joint] is None:
+ self._ema[joint] = float(value)
+ else:
+ self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint]
+
+ # Convert back to int for compatibility with normalization
+ smoothed[joint] = int(round(self._ema[joint]))
+ return smoothed
+
+ def _read(
+ self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
+ ) -> dict[str, int | float]:
+ """
+ Return the most recent (single) values from self.last_d,
+ optionally applying calibration.
+ """
+ if not self.new_state_event.wait(timeout=timeout):
+ raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
+
+ with self.state_lock:
+ state = self._state
+
+ self.new_state_event.clear()
+
+ if state is None:
+ raise RuntimeError(f"{self} Internal error: Event set but no state available.")
+
+ if joints is not None:
+ state = {k: v for k, v in state.items() if k in joints}
+
+ # Apply EMA smoothing to raw values first
+ state = self._apply_ema(state)
+
+ # Then normalize if requested
+ if normalize:
+ state = self._normalize(state)
+
+ return state
+
+ def _read_loop(self):
+ """
+ Continuously read from the serial buffer in its own thread and sends values to the main thread through
+ a queue.
+ """
+ while not self.stop_event.is_set():
+ try:
+ positions = None
+ with self.serial_lock:
+ if self.serial.in_waiting > 0:
+ self.serial.flush()
+ positions = self.serial.readline().decode("utf-8").strip().split(" ")
+ if positions is None or len(positions) != len(self.joints):
+ continue
+
+ joint_positions = {joint: int(pos) for joint, pos in zip(self.joints, positions, strict=True)}
+
+ with self.state_lock:
+ self._state = joint_positions
+ self.new_state_event.set()
+
+ except Exception as e:
+ logger.debug(f"Error reading frame in background thread for {self}: {e}")
+
+ def get_action(self) -> dict[str, float]:
+ joint_positions = self._read()
+ return homunculus_glove_to_hope_jr_hand(
+ {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
+ )
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ raise NotImplementedError
+
+ def disconnect(self) -> None:
+ if not self.is_connected:
+ DeviceNotConnectedError(f"{self} is not connected.")
+
+ self.stop_event.set()
+ self.thread.join(timeout=1)
+ self.serial.close()
+ logger.info(f"{self} disconnected.")
diff --git a/src/lerobot/teleoperators/homunculus/joints_translation.py b/src/lerobot/teleoperators/homunculus/joints_translation.py
new file mode 100644
index 000000000..f14f7b3ef
--- /dev/null
+++ b/src/lerobot/teleoperators/homunculus/joints_translation.py
@@ -0,0 +1,63 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+INDEX_SPLAY = 0.3
+MIDDLE_SPLAY = 0.3
+RING_SPLAY = 0.3
+PINKY_SPLAY = 0.5
+
+
+def get_ulnar_flexion(flexion: float, abduction: float, splay: float):
+ return -abduction * splay + flexion * (1 - splay)
+
+
+def get_radial_flexion(flexion: float, abduction: float, splay: float):
+ return abduction * splay + flexion * (1 - splay)
+
+
+def homunculus_glove_to_hope_jr_hand(glove_action: dict[str, float]) -> dict[str, float]:
+ return {
+ "thumb_cmc.pos": glove_action["thumb_cmc.pos"],
+ "thumb_mcp.pos": glove_action["thumb_mcp.pos"],
+ "thumb_pip.pos": glove_action["thumb_pip.pos"],
+ "thumb_dip.pos": glove_action["thumb_dip.pos"],
+ "index_radial_flexor.pos": get_radial_flexion(
+ glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY
+ ),
+ "index_ulnar_flexor.pos": get_ulnar_flexion(
+ glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY
+ ),
+ "index_pip_dip.pos": glove_action["index_dip.pos"],
+ "middle_radial_flexor.pos": get_radial_flexion(
+ glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY
+ ),
+ "middle_ulnar_flexor.pos": get_ulnar_flexion(
+ glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY
+ ),
+ "middle_pip_dip.pos": glove_action["middle_dip.pos"],
+ "ring_radial_flexor.pos": get_radial_flexion(
+ glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY
+ ),
+ "ring_ulnar_flexor.pos": get_ulnar_flexion(
+ glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY
+ ),
+ "ring_pip_dip.pos": glove_action["ring_dip.pos"],
+ "pinky_radial_flexor.pos": get_radial_flexion(
+ glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY
+ ),
+ "pinky_ulnar_flexor.pos": get_ulnar_flexion(
+ glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY
+ ),
+ "pinky_pip_dip.pos": glove_action["pinky_dip.pos"],
+ }
diff --git a/src/lerobot/teleoperators/keyboard/__init__.py b/src/lerobot/teleoperators/keyboard/__init__.py
index 5761bf788..72d01003a 100644
--- a/src/lerobot/teleoperators/keyboard/__init__.py
+++ b/src/lerobot/teleoperators/keyboard/__init__.py
@@ -1,3 +1,19 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
from .teleop_keyboard import KeyboardEndEffectorTeleop, KeyboardTeleop
diff --git a/src/lerobot/teleoperators/keyboard/configuration_keyboard.py b/src/lerobot/teleoperators/keyboard/configuration_keyboard.py
index 5d5ef364f..6e070dedd 100644
--- a/src/lerobot/teleoperators/keyboard/configuration_keyboard.py
+++ b/src/lerobot/teleoperators/keyboard/configuration_keyboard.py
@@ -22,8 +22,9 @@ from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("keyboard")
@dataclass
class KeyboardTeleopConfig(TeleoperatorConfig):
+ """KeyboardTeleopConfig"""
+
# TODO(Steven): Consider setting in here the keys that we want to capture/listen
- mock: bool = False
@TeleoperatorConfig.register_subclass("keyboard_ee")
diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
index d034982f1..6f53a17c7 100644
--- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
+++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
@@ -21,9 +21,10 @@ import time
from queue import Queue
from typing import Any
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
+from ..utils import TeleopEvents
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
PYNPUT_AVAILABLE = True
@@ -176,16 +177,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
}
- def _on_press(self, key):
- if hasattr(key, "char"):
- key = key.char
- self.event_queue.put((key, True))
-
- def _on_release(self, key):
- if hasattr(key, "char"):
- key = key.char
- self.event_queue.put((key, False))
-
def get_action(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(
@@ -235,3 +226,66 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
action_dict["gripper"] = gripper_action
return action_dict
+
+ def get_teleop_events(self) -> dict[str, Any]:
+ """
+ Get extra control events from the keyboard such as intervention status,
+ episode termination, success indicators, etc.
+
+ Keyboard mappings:
+ - Any movement keys pressed = intervention active
+ - 's' key = success (terminate episode successfully)
+ - 'r' key = rerecord episode (terminate and rerecord)
+ - 'q' key = quit episode (terminate without success)
+
+ Returns:
+ Dictionary containing:
+ - is_intervention: bool - Whether human is currently intervening
+ - terminate_episode: bool - Whether to terminate the current episode
+ - success: bool - Whether the episode was successful
+ - rerecord_episode: bool - Whether to rerecord the episode
+ """
+ if not self.is_connected:
+ return {
+ TeleopEvents.IS_INTERVENTION: False,
+ TeleopEvents.TERMINATE_EPISODE: False,
+ TeleopEvents.SUCCESS: False,
+ TeleopEvents.RERECORD_EPISODE: False,
+ }
+
+ # Check if any movement keys are currently pressed (indicates intervention)
+ movement_keys = [
+ keyboard.Key.up,
+ keyboard.Key.down,
+ keyboard.Key.left,
+ keyboard.Key.right,
+ keyboard.Key.shift,
+ keyboard.Key.shift_r,
+ keyboard.Key.ctrl_r,
+ keyboard.Key.ctrl_l,
+ ]
+ is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
+
+ # Check for episode control commands from misc_keys_queue
+ terminate_episode = False
+ success = False
+ rerecord_episode = False
+
+ # Process any pending misc keys
+ while not self.misc_keys_queue.empty():
+ key = self.misc_keys_queue.get_nowait()
+ if key == "s":
+ success = True
+ elif key == "r":
+ terminate_episode = True
+ rerecord_episode = True
+ elif key == "q":
+ terminate_episode = True
+ success = False
+
+ return {
+ TeleopEvents.IS_INTERVENTION: is_intervention,
+ TeleopEvents.TERMINATE_EPISODE: terminate_episode,
+ TeleopEvents.SUCCESS: success,
+ TeleopEvents.RERECORD_EPISODE: rerecord_episode,
+ }
diff --git a/src/lerobot/teleoperators/koch_leader/__init__.py b/src/lerobot/teleoperators/koch_leader/__init__.py
index ad2d6a0e4..1bf9d51db 100644
--- a/src/lerobot/teleoperators/koch_leader/__init__.py
+++ b/src/lerobot/teleoperators/koch_leader/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_koch_leader import KochLeaderConfig
from .koch_leader import KochLeader
diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py
index 8eb076fae..0409f2e57 100644
--- a/src/lerobot/teleoperators/koch_leader/koch_leader.py
+++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py
@@ -17,13 +17,13 @@
import logging
import time
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_koch_leader import KochLeaderConfig
@@ -75,6 +75,9 @@ class KochLeader(Teleoperator):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
self.configure()
@@ -85,8 +88,17 @@ class KochLeader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
- logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+ logger.info(f"\nRunning calibration of {self}")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py
new file mode 100644
index 000000000..2b28c1f97
--- /dev/null
+++ b/src/lerobot/teleoperators/phone/__init__.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config_phone import PhoneConfig
+from .teleop_phone import Phone
diff --git a/src/lerobot/teleoperators/phone/config_phone.py b/src/lerobot/teleoperators/phone/config_phone.py
new file mode 100644
index 000000000..380d5f5ff
--- /dev/null
+++ b/src/lerobot/teleoperators/phone/config_phone.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from enum import Enum
+
+import numpy as np
+
+from ..config import TeleoperatorConfig
+
+
+class PhoneOS(Enum):
+ ANDROID = "android"
+ IOS = "ios"
+
+
+@TeleoperatorConfig.register_subclass("phone")
+@dataclass
+class PhoneConfig(TeleoperatorConfig):
+ phone_os: PhoneOS = PhoneOS.IOS
+ camera_offset = np.array(
+ [0.0, -0.02, 0.04]
+ ) # iPhone 14 Pro camera is 2cm off center and 4cm above center
diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py
new file mode 100644
index 000000000..67e64c7d5
--- /dev/null
+++ b/src/lerobot/teleoperators/phone/phone_processor.py
@@ -0,0 +1,110 @@
+# !/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep
+from lerobot.teleoperators.phone.config_phone import PhoneOS
+
+
+@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
+@dataclass
+class MapPhoneActionToRobotAction(RobotActionProcessorStep):
+ """
+ Maps calibrated phone pose actions to standardized robot action inputs.
+
+ This processor step acts as a bridge between the phone teleoperator's output
+ and the robot's expected action format. It remaps the phone's 6-DoF pose
+ (position and rotation) to the robot's target end-effector pose, applying
+ necessary axis inversions and swaps. It also interprets platform-specific
+ button presses to generate a gripper command.
+
+ Attributes:
+ platform: The operating system of the phone (iOS or Android), used
+ to determine the correct button mappings for the gripper.
+ """
+
+ # TODO(Steven): Gripper vel could be output of phone_teleop directly
+ platform: PhoneOS
+ _enabled_prev: bool = field(default=False, init=False, repr=False)
+
+ def action(self, action: RobotAction) -> RobotAction:
+ """
+ Processes the phone action dictionary to create a robot action dictionary.
+
+ Args:
+ act: The input action dictionary from the phone teleoperator.
+
+ Returns:
+ A new action dictionary formatted for the robot controller.
+
+ Raises:
+ ValueError: If 'pos' or 'rot' keys are missing from the input action.
+ """
+ # Pop them from the action
+ enabled = bool(action.pop("phone.enabled"))
+ pos = action.pop("phone.pos")
+ rot = action.pop("phone.rot")
+ inputs = action.pop("phone.raw_inputs")
+
+ if pos is None or rot is None:
+ raise ValueError("pos and rot must be present in action")
+
+ rotvec = rot.as_rotvec() # Absolute orientation as rotvec
+
+ # Map certain inputs to certain actions
+ if self.platform == PhoneOS.IOS:
+ gripper_vel = float(inputs.get("a3", 0.0))
+ else:
+ a = float(inputs.get("reservedButtonA", 0.0))
+ b = float(inputs.get("reservedButtonB", 0.0))
+ gripper_vel = (
+ a - b
+ ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
+
+ # For some actions we need to invert the axis
+ action["enabled"] = enabled
+ action["target_x"] = -pos[1] if enabled else 0.0
+ action["target_y"] = pos[0] if enabled else 0.0
+ action["target_z"] = pos[2] if enabled else 0.0
+ action["target_wx"] = rotvec[1] if enabled else 0.0
+ action["target_wy"] = rotvec[0] if enabled else 0.0
+ action["target_wz"] = -rotvec[2] if enabled else 0.0
+ action["gripper_vel"] = gripper_vel # Still send gripper action when disabled
+ return action
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ for feat in ["enabled", "pos", "rot", "raw_inputs"]:
+ features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None)
+
+ for feat in [
+ "enabled",
+ "target_x",
+ "target_y",
+ "target_z",
+ "target_wx",
+ "target_wy",
+ "target_wz",
+ "gripper_vel",
+ ]:
+ features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
+ type=FeatureType.ACTION, shape=(1,)
+ )
+
+ return features
diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py
new file mode 100644
index 000000000..91e613190
--- /dev/null
+++ b/src/lerobot/teleoperators/phone/teleop_phone.py
@@ -0,0 +1,421 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Docs:
+# hebi: https://docs.hebi.us/tools.html#mobile-io
+# teleop: https://github.com/SpesRobotics/teleop
+
+import logging
+import threading
+import time
+
+import hebi
+import numpy as np
+from teleop import Teleop
+
+from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
+from lerobot.teleoperators.teleoperator import Teleoperator
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from lerobot.utils.rotation import Rotation
+
+logger = logging.getLogger(__name__)
+
+
+class BasePhone:
+ _enabled: bool = False
+ _calib_pos: np.ndarray | None = None
+ _calib_rot_inv: Rotation | None = None
+
+ def _reapply_position_calibration(self, pos: np.ndarray) -> None:
+ self._calib_pos = pos.copy()
+
+ @property
+ def is_calibrated(self) -> bool:
+ return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
+
+ @property
+ def action_features(self) -> dict[str, type]:
+ return {
+ "phone.pos": np.ndarray, # shape (3,)
+ "phone.rot": Rotation, # scipy.spatial.transform.Rotation
+ "phone.raw_inputs": dict, # analogs/buttons or webXR meta
+ "phone.enabled": bool,
+ }
+
+ @property
+ def feedback_features(self) -> dict[str, type]:
+ # No haptic or other feedback implemented yet
+ pass
+
+ def configure(self) -> None:
+ # No additional configuration required for phone teleop
+ pass
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ # We could add haptic feedback (vibrations) here, but it's not implemented yet
+ raise NotImplementedError
+
+
+class IOSPhone(BasePhone, Teleoperator):
+ name = "ios_phone"
+
+ def __init__(self, config: PhoneConfig):
+ super().__init__(config)
+ self.config = config
+ self._group = None
+
+ @property
+ def is_connected(self) -> bool:
+ return self._group is not None
+
+ def connect(self) -> None:
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
+ lookup = hebi.Lookup()
+ time.sleep(2.0)
+ group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
+ if group is None:
+ raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
+ self._group = group
+ logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
+
+ self.calibrate()
+
+ def calibrate(self) -> None:
+ print(
+ "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
+ )
+ print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
+ position, rotation = self._wait_for_capture_trigger()
+ self._calib_pos = position.copy()
+ self._calib_rot_inv = rotation.inv()
+ self._enabled = False
+ print("Calibration done\n")
+
+ def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
+ """
+ Blocks execution until the calibration trigger is detected from the iOS device.
+
+ This method enters a loop, continuously reading the phone's state. It waits for the user to press
+ and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and
+ returns the phone's pose at that exact moment.
+
+ Returns:
+ A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
+ moment the trigger was activated.
+ """
+ while True:
+ has_pose, position, rotation, fb_pose = self._read_current_pose()
+ if not has_pose:
+ time.sleep(0.01)
+ continue
+
+ io = getattr(fb_pose, "io", None)
+ button_b = getattr(io, "b", None) if io is not None else None
+ button_b1_pressed = False
+ if button_b is not None:
+ button_b1_pressed = bool(button_b.get_int(1))
+ if button_b1_pressed:
+ return position, rotation
+
+ time.sleep(0.01)
+
+ def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
+ """
+ Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK.
+
+ This method fetches the latest feedback packet from the HEBI group, extracts the ARKit
+ position and orientation, and converts them into a standard format. It also applies a
+ configured camera offset to adjust the pose from the camera's frame to the phone's
+ physical frame.
+
+ Returns:
+ A tuple containing:
+ - A boolean indicating if a valid pose was successfully read.
+ - The 3D position as a NumPy array, or None if not available.
+ - The orientation as a `Rotation` object, or None if not available.
+ - The raw HEBI feedback object for accessing other data like button presses.
+ """
+ fbk = self._group.get_next_feedback()
+ pose = fbk[0]
+ ar_pos = getattr(pose, "ar_position", None)
+ ar_quat = getattr(pose, "ar_orientation", None)
+ if ar_pos is None or ar_quat is None:
+ return False, None, None, None
+ # HEBI provides orientation in w, x, y, z format.
+ # Scipy's Rotation expects x, y, z, w.
+ quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
+ rot = Rotation.from_quat(quat_xyzw)
+ pos = ar_pos - rot.apply(self.config.camera_offset)
+ return True, pos, rot, pose
+
+ def get_action(self) -> dict:
+ has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
+ if not has_pose or not self.is_calibrated:
+ return {}
+
+ # Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
+ raw_inputs: dict[str, float | int | bool] = {}
+ io = getattr(fb_pose, "io", None)
+ if io is not None:
+ bank_a, bank_b = io.a, io.b
+ if bank_a:
+ for ch in range(1, 9):
+ if bank_a.has_float(ch):
+ raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
+ if bank_b:
+ for ch in range(1, 9):
+ if bank_b.has_int(ch):
+ raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
+ elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
+ raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
+
+ enable = bool(raw_inputs.get("b1", 0))
+
+ # Rising edge then re-capture calibration immediately from current raw pose
+ if enable and not self._enabled:
+ self._reapply_position_calibration(raw_position)
+
+ # Apply calibration
+ pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
+ rot_cal = self._calib_rot_inv * raw_rotation
+
+ self._enabled = enable
+
+ return {
+ "phone.pos": pos_cal,
+ "phone.rot": rot_cal,
+ "phone.raw_inputs": raw_inputs,
+ "phone.enabled": self._enabled,
+ }
+
+ def disconnect(self) -> None:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ self._group = None
+
+
+class AndroidPhone(BasePhone, Teleoperator):
+ name = "android_phone"
+
+ def __init__(self, config: PhoneConfig):
+ super().__init__(config)
+ self.config = config
+ self._teleop = None
+ self._teleop_thread = None
+ self._latest_pose = None
+ self._latest_message = None
+ self._android_lock = threading.Lock()
+
+ @property
+ def is_connected(self) -> bool:
+ return self._teleop is not None
+
+ def connect(self) -> None:
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} already connected")
+
+ logger.info("Starting teleop stream for Android...")
+ self._teleop = Teleop()
+ self._teleop.subscribe(self._android_callback)
+ self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
+ self._teleop_thread.start()
+ logger.info(f"{self} connected, teleop stream started.")
+
+ self.calibrate()
+
+ def calibrate(self) -> None:
+ print(
+ "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
+ )
+ print("Touch and move on the WebXR page to capture this pose...\n")
+
+ pos, rot = self._wait_for_capture_trigger()
+ self._calib_pos = pos.copy()
+ self._calib_rot_inv = rot.inv()
+ self._enabled = False
+ print("Calibration done\n")
+
+ def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
+ """
+ Blocks execution until the calibration trigger is detected from the Android device.
+
+ This method enters a loop, continuously checking the latest message received from the WebXR
+ session. It waits for the user to touch and move their finger on the screen, which generates
+ a `move` event. Once this event is detected, the loop breaks and returns the phone's current
+ pose.
+
+ Returns:
+ A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
+ moment the trigger was activated.
+ """
+ while True:
+ with self._android_lock:
+ msg = self._latest_message or {}
+
+ if bool(msg.get("move", False)):
+ ok, pos, rot, _pose = self._read_current_pose()
+ if ok:
+ return pos, rot
+
+ time.sleep(0.01)
+
+ def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
+ """
+ Reads the latest 6-DoF pose received from the Android device's WebXR session.
+
+ This method accesses the most recent pose data stored by the `_android_callback`. It uses a
+ thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is
+ then decomposed into position and rotation, and the configured camera offset is applied.
+
+ Returns:
+ A tuple containing:
+ - A boolean indicating if a valid pose was available.
+ - The 3D position as a NumPy array, or None if no pose has been received yet.
+ - The orientation as a `Rotation` object, or None if no pose has been received.
+ - The raw 4x4 pose matrix as received from the teleop stream.
+ """
+ with self._android_lock:
+ if self._latest_pose is None:
+ return False, None, None, None
+ p = self._latest_pose.copy()
+ pose = self._latest_pose
+ rot = Rotation.from_matrix(p[:3, :3])
+ pos = p[:3, 3] - rot.apply(self.config.camera_offset)
+ return True, pos, rot, pose
+
+ def _android_callback(self, pose: np.ndarray, message: dict) -> None:
+ """
+ Callback function to handle incoming data from the Android teleop stream.
+
+ This method is executed by the `teleop` package's subscriber thread whenever a new
+ pose and message are received from the WebXR session on the Android phone. It updates
+ the internal state (`_latest_pose` and `_latest_message`) with the new data.
+ A thread lock is used to ensure that these shared variables are updated atomically,
+ preventing race conditions with the main thread that reads them.
+
+ Args:
+ pose: A 4x4 NumPy array representing the phone's transformation matrix.
+ message: A dictionary containing additional data, such as button presses or touch events.
+ """
+ with self._android_lock:
+ self._latest_pose = pose
+ self._latest_message = message
+
+ def get_action(self) -> dict:
+ ok, raw_pos, raw_rot, pose = self._read_current_pose()
+ if not ok or not self.is_calibrated:
+ return {}
+
+ # Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
+ raw_inputs: dict[str, float | int | bool] = {}
+ msg = self._latest_message or {}
+ raw_inputs["move"] = bool(msg.get("move", False))
+ raw_inputs["scale"] = float(msg.get("scale", 1.0))
+ raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
+ raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
+
+ enable = bool(raw_inputs.get("move", False))
+
+ # Rising edge then re-capture calibration immediately from current raw pose
+ if enable and not self._enabled:
+ self._reapply_position_calibration(raw_pos)
+
+ # Apply calibration
+ pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
+ rot_cal = self._calib_rot_inv * raw_rot
+
+ self._enabled = enable
+
+ return {
+ "phone.pos": pos_cal,
+ "phone.rot": rot_cal,
+ "phone.raw_inputs": raw_inputs,
+ "phone.enabled": self._enabled,
+ }
+
+ def disconnect(self) -> None:
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ self._teleop = None
+ if self._teleop_thread and self._teleop_thread.is_alive():
+ self._teleop_thread.join(timeout=1.0)
+ self._teleop_thread = None
+ self._latest_pose = None
+
+
+class Phone(Teleoperator):
+ """
+ Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
+ For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
+
+ Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
+ captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
+ """
+
+ config_class = PhoneConfig
+ name = "phone"
+
+ def __init__(self, config: PhoneConfig):
+ super().__init__(config)
+ self.config = config
+
+ self._phone_impl: Teleoperator
+
+ if self.config.phone_os == PhoneOS.IOS:
+ self._phone_impl = IOSPhone(config)
+ elif self.config.phone_os == PhoneOS.ANDROID:
+ self._phone_impl = AndroidPhone(config)
+ else:
+ raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
+
+ @property
+ def is_connected(self) -> bool:
+ return self._phone_impl.is_connected
+
+ def connect(self) -> None:
+ return self._phone_impl.connect()
+
+ def calibrate(self) -> None:
+ return self._phone_impl.calibrate()
+
+ @property
+ def is_calibrated(self) -> bool:
+ return self._phone_impl.is_calibrated
+
+ @property
+ def action_features(self) -> dict[str, type]:
+ return self._phone_impl.action_features
+
+ @property
+ def feedback_features(self) -> dict[str, type]:
+ return self._phone_impl.feedback_features
+
+ def configure(self) -> None:
+ return self._phone_impl.configure()
+
+ def get_action(self) -> dict:
+ return self._phone_impl.get_action()
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ return self._phone_impl.send_feedback(feedback)
+
+ def disconnect(self) -> None:
+ return self._phone_impl.disconnect()
diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py
new file mode 100644
index 000000000..a07a4a6cd
--- /dev/null
+++ b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
+from .reachy2_teleoperator import (
+ REACHY2_ANTENNAS_JOINTS,
+ REACHY2_L_ARM_JOINTS,
+ REACHY2_NECK_JOINTS,
+ REACHY2_R_ARM_JOINTS,
+ REACHY2_VEL,
+ Reachy2Teleoperator,
+)
diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py
new file mode 100644
index 000000000..4e615d363
--- /dev/null
+++ b/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+from ..config import TeleoperatorConfig
+
+
+@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
+@dataclass
+class Reachy2TeleoperatorConfig(TeleoperatorConfig):
+ # IP address of the Reachy 2 robot used as teleoperator
+ ip_address: str | None = "localhost"
+
+ # Whether to use the present position of the joints as actions
+ # if False, the goal position of the joints will be used
+ use_present_position: bool = False
+
+ # Which parts of the robot to use
+ with_mobile_base: bool = True
+ with_l_arm: bool = True
+ with_r_arm: bool = True
+ with_neck: bool = True
+ with_antennas: bool = True
+
+ def __post_init__(self):
+ if not (
+ self.with_mobile_base
+ or self.with_l_arm
+ or self.with_r_arm
+ or self.with_neck
+ or self.with_antennas
+ ):
+ raise ValueError(
+ "No Reachy2Teleoperator part used.\n"
+ "At least one part of the robot must be set to True "
+ "(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
+ )
diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py
new file mode 100644
index 000000000..5a427dd71
--- /dev/null
+++ b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+from reachy2_sdk import ReachySDK
+
+from ..teleoperator import Teleoperator
+from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
+
+logger = logging.getLogger(__name__)
+
+# {lerobot_keys: reachy2_sdk_keys}
+REACHY2_NECK_JOINTS = {
+ "neck_yaw.pos": "head.neck.yaw",
+ "neck_pitch.pos": "head.neck.pitch",
+ "neck_roll.pos": "head.neck.roll",
+}
+
+REACHY2_ANTENNAS_JOINTS = {
+ "l_antenna.pos": "head.l_antenna",
+ "r_antenna.pos": "head.r_antenna",
+}
+
+REACHY2_R_ARM_JOINTS = {
+ "r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
+ "r_shoulder_roll.pos": "r_arm.shoulder.roll",
+ "r_elbow_yaw.pos": "r_arm.elbow.yaw",
+ "r_elbow_pitch.pos": "r_arm.elbow.pitch",
+ "r_wrist_roll.pos": "r_arm.wrist.roll",
+ "r_wrist_pitch.pos": "r_arm.wrist.pitch",
+ "r_wrist_yaw.pos": "r_arm.wrist.yaw",
+ "r_gripper.pos": "r_arm.gripper",
+}
+
+REACHY2_L_ARM_JOINTS = {
+ "l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
+ "l_shoulder_roll.pos": "l_arm.shoulder.roll",
+ "l_elbow_yaw.pos": "l_arm.elbow.yaw",
+ "l_elbow_pitch.pos": "l_arm.elbow.pitch",
+ "l_wrist_roll.pos": "l_arm.wrist.roll",
+ "l_wrist_pitch.pos": "l_arm.wrist.pitch",
+ "l_wrist_yaw.pos": "l_arm.wrist.yaw",
+ "l_gripper.pos": "l_arm.gripper",
+}
+
+REACHY2_VEL = {
+ "mobile_base.vx": "vx",
+ "mobile_base.vy": "vy",
+ "mobile_base.vtheta": "vtheta",
+}
+
+
+class Reachy2Teleoperator(Teleoperator):
+ """
+ [Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
+ """
+
+ config_class = Reachy2TeleoperatorConfig
+ name = "reachy2_specific"
+
+ def __init__(self, config: Reachy2TeleoperatorConfig):
+ super().__init__(config)
+ self.config = config
+ self.reachy: None | ReachySDK = None
+
+ self.joints_dict: dict[str, str] = self._generate_joints_dict()
+
+ def _generate_joints_dict(self) -> dict[str, str]:
+ joints = {}
+ if self.config.with_neck:
+ joints.update(REACHY2_NECK_JOINTS)
+ if self.config.with_l_arm:
+ joints.update(REACHY2_L_ARM_JOINTS)
+ if self.config.with_r_arm:
+ joints.update(REACHY2_R_ARM_JOINTS)
+ if self.config.with_antennas:
+ joints.update(REACHY2_ANTENNAS_JOINTS)
+ return joints
+
+ @property
+ def action_features(self) -> dict[str, type]:
+ if self.config.with_mobile_base:
+ return {
+ **dict.fromkeys(
+ self.joints_dict.keys(),
+ float,
+ ),
+ **dict.fromkeys(
+ REACHY2_VEL.keys(),
+ float,
+ ),
+ }
+ else:
+ return dict.fromkeys(self.joints_dict.keys(), float)
+
+ @property
+ def feedback_features(self) -> dict[str, type]:
+ return {}
+
+ @property
+ def is_connected(self) -> bool:
+ return self.reachy.is_connected() if self.reachy is not None else False
+
+ def connect(self, calibrate: bool = True) -> None:
+ self.reachy = ReachySDK(self.config.ip_address)
+ if not self.is_connected:
+ raise ConnectionError()
+ logger.info(f"{self} connected.")
+
+ @property
+ def is_calibrated(self) -> bool:
+ return True
+
+ def calibrate(self) -> None:
+ pass
+
+ def configure(self) -> None:
+ pass
+
+ def get_action(self) -> dict[str, float]:
+ start = time.perf_counter()
+
+ if self.reachy and self.is_connected:
+ if self.config.use_present_position:
+ joint_action = {
+ k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
+ }
+ else:
+ joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
+
+ if not self.config.with_mobile_base:
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read action: {dt_ms:.1f}ms")
+ return joint_action
+
+ if self.config.use_present_position:
+ vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
+ else:
+ vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
+ dt_ms = (time.perf_counter() - start) * 1e3
+ logger.debug(f"{self} read action: {dt_ms:.1f}ms")
+ return {**joint_action, **vel_action}
+
+ def send_feedback(self, feedback: dict[str, float]) -> None:
+ raise NotImplementedError
+
+ def disconnect(self) -> None:
+ if self.reachy and self.is_connected:
+ self.reachy.disconnect()
diff --git a/src/lerobot/teleoperators/so100_leader/__init__.py b/src/lerobot/teleoperators/so100_leader/__init__.py
index 63c877e60..747416be2 100644
--- a/src/lerobot/teleoperators/so100_leader/__init__.py
+++ b/src/lerobot/teleoperators/so100_leader/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_so100_leader import SO100LeaderConfig
from .so100_leader import SO100Leader
diff --git a/src/lerobot/teleoperators/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so100_leader/so100_leader.py
index 18dad44d4..edcfe53e6 100644
--- a/src/lerobot/teleoperators/so100_leader/so100_leader.py
+++ b/src/lerobot/teleoperators/so100_leader/so100_leader.py
@@ -17,12 +17,12 @@
import logging
import time
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_so100_leader import SO100LeaderConfig
@@ -72,6 +72,9 @@ class SO100Leader(Teleoperator):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
self.configure()
@@ -82,6 +85,16 @@ class SO100Leader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
diff --git a/src/lerobot/teleoperators/so101_leader/__init__.py b/src/lerobot/teleoperators/so101_leader/__init__.py
index 1f45170e9..11e277c91 100644
--- a/src/lerobot/teleoperators/so101_leader/__init__.py
+++ b/src/lerobot/teleoperators/so101_leader/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_so101_leader import SO101LeaderConfig
from .so101_leader import SO101Leader
diff --git a/src/lerobot/teleoperators/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so101_leader/so101_leader.py
index 2ce28d2e4..be804bf70 100644
--- a/src/lerobot/teleoperators/so101_leader/so101_leader.py
+++ b/src/lerobot/teleoperators/so101_leader/so101_leader.py
@@ -17,12 +17,12 @@
import logging
import time
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_so101_leader import SO101LeaderConfig
@@ -73,6 +73,9 @@ class SO101Leader(Teleoperator):
self.bus.connect()
if not self.is_calibrated and calibrate:
+ logger.info(
+ "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
+ )
self.calibrate()
self.configure()
@@ -83,6 +86,16 @@ class SO101Leader(Teleoperator):
return self.bus.is_calibrated
def calibrate(self) -> None:
+ if self.calibration:
+ # Calibration file exists, ask user whether to use it or run new calibration
+ user_input = input(
+ f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
+ )
+ if user_input.strip().lower() != "c":
+ logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
+ self.bus.write_calibration(self.calibration)
+ return
+
logger.info(f"\nRunning calibration of {self}")
self.bus.disable_torque()
for motor in self.bus.motors:
diff --git a/src/lerobot/teleoperators/stretch3_gamepad/__init__.py b/src/lerobot/teleoperators/stretch3_gamepad/__init__.py
index ac45b6dd4..fa5a19974 100644
--- a/src/lerobot/teleoperators/stretch3_gamepad/__init__.py
+++ b/src/lerobot/teleoperators/stretch3_gamepad/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .configuration_stretch3 import Stretch3GamePadConfig
from .stretch3_gamepad import Stretch3GamePad
diff --git a/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py b/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py
index 507a21589..3af0b5be1 100644
--- a/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py
+++ b/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py
@@ -22,4 +22,4 @@ from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("stretch3")
@dataclass
class Stretch3GamePadConfig(TeleoperatorConfig):
- mock: bool = False
+ """Stretch3GamePadConfig"""
diff --git a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py
index bdcb57d40..09fdfadd7 100644
--- a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py
+++ b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py
@@ -20,7 +20,7 @@ import numpy as np
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot_params import RobotParams
-from lerobot.errors import DeviceAlreadyConnectedError
+from lerobot.utils.errors import DeviceAlreadyConnectedError
from ..teleoperator import Teleoperator
from .configuration_stretch3 import Stretch3GamePadConfig
@@ -112,10 +112,6 @@ class Stretch3GamePad(Teleoperator):
def send_feedback(self, feedback: np.ndarray) -> None:
pass
- def print_logs(self) -> None:
- pass
- # TODO(aliberts): move robot-specific logs logic here
-
def disconnect(self) -> None:
self.api.stop()
self.is_connected = False
diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py
index 49f259c17..95020a962 100644
--- a/src/lerobot/teleoperators/teleoperator.py
+++ b/src/lerobot/teleoperators/teleoperator.py
@@ -13,13 +13,14 @@
# limitations under the License.
import abc
+import builtins
from pathlib import Path
-from typing import Any, Type
+from typing import Any
import draccus
-from lerobot.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
from lerobot.motors.motors_bus import MotorCalibration
+from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
from .config import TeleoperatorConfig
@@ -37,7 +38,7 @@ class Teleoperator(abc.ABC):
"""
# Set these in ALL subclasses
- config_class: Type[TeleoperatorConfig]
+ config_class: builtins.type[TeleoperatorConfig]
name: str
def __init__(self, config: TeleoperatorConfig):
diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py
index b49addc15..ada7ee8a1 100644
--- a/src/lerobot/teleoperators/utils.py
+++ b/src/lerobot/teleoperators/utils.py
@@ -12,11 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from enum import Enum
+from typing import cast
+
+from lerobot.utils.import_utils import make_device_from_device_class
+
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
+class TeleopEvents(Enum):
+ """Shared constants for teleoperator events across teleoperators."""
+
+ SUCCESS = "success"
+ FAILURE = "failure"
+ RERECORD_EPISODE = "rerecord_episode"
+ IS_INTERVENTION = "is_intervention"
+ TERMINATE_EPISODE = "terminate_episode"
+
+
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
+ # TODO(Steven): Consider just using the make_device_from_device_class for all types
if config.type == "keyboard":
from .keyboard import KeyboardTeleop
@@ -53,5 +69,24 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
from .keyboard.teleop_keyboard import KeyboardEndEffectorTeleop
return KeyboardEndEffectorTeleop(config)
+ elif config.type == "homunculus_glove":
+ from .homunculus import HomunculusGlove
+
+ return HomunculusGlove(config)
+ elif config.type == "homunculus_arm":
+ from .homunculus import HomunculusArm
+
+ return HomunculusArm(config)
+ elif config.type == "bi_so100_leader":
+ from .bi_so100_leader import BiSO100Leader
+
+ return BiSO100Leader(config)
+ elif config.type == "reachy2_teleoperator":
+ from .reachy2_teleoperator import Reachy2Teleoperator
+
+ return Reachy2Teleoperator(config)
else:
- raise ValueError(config.type)
+ try:
+ return cast(Teleoperator, make_device_from_device_class(config))
+ except Exception as e:
+ raise ValueError(f"Error creating robot with config {config}: {e}") from e
diff --git a/src/lerobot/teleoperators/widowx/__init__.py b/src/lerobot/teleoperators/widowx/__init__.py
index 122ee3290..42e312f49 100644
--- a/src/lerobot/teleoperators/widowx/__init__.py
+++ b/src/lerobot/teleoperators/widowx/__init__.py
@@ -1,2 +1,18 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from .config_widowx import WidowXConfig
from .widowx import WidowX
diff --git a/src/lerobot/teleoperators/widowx/widowx.py b/src/lerobot/teleoperators/widowx/widowx.py
index 6becd767f..1a00bd4d2 100644
--- a/src/lerobot/teleoperators/widowx/widowx.py
+++ b/src/lerobot/teleoperators/widowx/widowx.py
@@ -17,13 +17,13 @@
import logging
import time
-from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
DynamixelMotorsBus,
OperatingMode,
)
+from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_widowx import WidowXConfig
diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md
index 64ad7196c..34af282b0 100644
--- a/src/lerobot/templates/lerobot_modelcard_template.md
+++ b/src/lerobot/templates/lerobot_modelcard_template.md
@@ -1,7 +1,8 @@
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
-{{ card_data }}
+# prettier-ignore
+{{card_data}}
---
# Model Card for {{ model_name | default("Model ID", true) }}
@@ -18,10 +19,28 @@
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
-{% elif model_name == "pi0" %}
-[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer.
{% elif model_name == "pi0fast" %}
[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
+{% elif model_name == "pi0" %}
+**π₀ (Pi0)**
+
+π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
+
+**Model Overview**
+
+π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
+
+For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
+{% elif model_name == "pi05" %}
+**π₀.₅ (Pi05) Policy**
+
+π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
+
+**Model Overview**
+
+π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
+
+For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "sac" %}
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
{% elif model_name == "reward_classifier" %}
@@ -43,7 +62,7 @@ Below is the short version on how to train and run inference/eval:
### Train from scratch
```bash
-python -m lerobot.scripts.train \
+lerobot-train \
--dataset.repo_id=${HF_USER}/ \
--policy.type=act \
--output_dir=outputs/train/ \
@@ -53,12 +72,12 @@ python -m lerobot.scripts.train \
--wandb.enable=true
```
-*Writes checkpoints to `outputs/train//checkpoints/`.*
+_Writes checkpoints to `outputs/train//checkpoints/`._
### Evaluate the policy/run inference
```bash
-python -m lerobot.record \
+lerobot-record \
--robot.type=so100_follower \
--dataset.repo_id=/eval_ \
--policy.path=/ \
@@ -71,4 +90,4 @@ Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a
## Model Details
-* **License:** {{ license | default("\[More Information Needed]", true) }}
+- **License:** {{ license | default("\[More Information Needed]", true) }}
diff --git a/src/lerobot/templates/visualize_dataset_homepage.html b/src/lerobot/templates/visualize_dataset_homepage.html
deleted file mode 100644
index 19613afb5..000000000
--- a/src/lerobot/templates/visualize_dataset_homepage.html
+++ /dev/null
@@ -1,68 +0,0 @@
-
-
-
-
-
- Interactive Video Background Page
-
-
-
-
-